Move get detail shape to backend anfalgo.

This commit is contained in:
ZPaC 2023-02-21 11:44:59 +08:00
parent 13cfed8045
commit 6cfd711064
73 changed files with 182 additions and 183 deletions

View File

@ -633,10 +633,10 @@ CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, cons
if (is_input) { if (is_input) {
auto node_input = common::AnfAlgo::GetInputNode(node, 0); auto node_input = common::AnfAlgo::GetInputNode(node, 0);
(void)new_cast_inputs.emplace_back(node_input); (void)new_cast_inputs.emplace_back(node_input);
shape = common::AnfAlgo::GetOutputDetailShape(node_input, 0); shape = AnfAlgo::GetOutputDetailShape(node_input, 0);
} else { } else {
(void)new_cast_inputs.emplace_back(node); (void)new_cast_inputs.emplace_back(node);
shape = common::AnfAlgo::GetOutputDetailShape(node, 0); shape = AnfAlgo::GetOutputDetailShape(node, 0);
} }
CNodePtr new_cast = NewCNode(new_cast_inputs, func_graph, {node}); CNodePtr new_cast = NewCNode(new_cast_inputs, func_graph, {node});
new_cast->set_scope(node->scope()); new_cast->set_scope(node->scope());

View File

@ -549,7 +549,7 @@ bool AnfRuntimeAlgorithm::IsRealSquenceOutput(const AnfNodePtr &node) {
std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t output_idx, std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t output_idx,
const std::string &format) { const std::string &format) {
auto output_shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx); auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
std::vector<int64_t> infer_shape; std::vector<int64_t> infer_shape;
if (output_shape->isa<abstract::Shape>()) { if (output_shape->isa<abstract::Shape>()) {
auto shape_ptr = output_shape->cast<abstract::ShapePtr>(); auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
@ -589,7 +589,7 @@ ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, si
std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t input_idx, std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t input_idx,
const std::string &format) { const std::string &format) {
auto output_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(node, input_idx); auto output_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, input_idx);
std::vector<int64_t> infer_shape; std::vector<int64_t> infer_shape;
if (output_shape->isa<abstract::Shape>()) { if (output_shape->isa<abstract::Shape>()) {
auto shape_ptr = output_shape->cast<abstract::ShapePtr>(); auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
@ -1733,7 +1733,49 @@ std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputObjectType(const AnfNodePtr
return {AnfAlgo::GetAbstractObjectType(node->abstract())}; return {AnfAlgo::GetAbstractObjectType(node->abstract())};
} }
std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node) { abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
auto base_shape = node->Shape();
MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>()) {
if (output_idx == 0) {
return base_shape;
}
MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
<< "]." << trace::DumpSourceLines(node);
} else if (base_shape->isa<abstract::TupleShape>()) {
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);
if (IsRealSquenceOutput(node)) {
return tuple_shape;
}
if (output_idx >= tuple_shape->size()) {
MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
<< " node:" << node->DebugString() << "." << trace::DumpSourceLines(node);
}
auto b_shp = (*tuple_shape)[output_idx];
if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>()) {
return b_shp;
} else {
MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
<< "node :" << node->DebugString() << "." << trace::DumpSourceLines(node);
}
} else if (base_shape->isa<abstract::NoShape>()) {
return base_shape;
} else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
return common::AnfAlgo::GetDynamicSequenceShape(node, output_idx);
}
MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
<< base_shape->ToString() << " node : " << node->DebugString() << trace::DumpSourceLines(node);
}
abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfAlgo::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
}
std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
std::vector<TypeId> outputs; std::vector<TypeId> outputs;
auto out_nums = AnfAlgo::GetOutputElementNum(node); auto out_nums = AnfAlgo::GetOutputElementNum(node);
@ -1746,7 +1788,7 @@ std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node)
// if input node is MakeTuple, find the PrevNodeNum recursively; // if input node is MakeTuple, find the PrevNodeNum recursively;
// The monad node in the end is not included in the num; // The monad node in the end is not included in the num;
size_t AnfAlgo::GetInputElementNum(const AnfNodePtr &node) { size_t AnfRuntimeAlgorithm::GetInputElementNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
@ -1772,7 +1814,7 @@ size_t AnfAlgo::GetInputElementNum(const AnfNodePtr &node) {
return element_num; return element_num;
} }
void AnfAlgo::SetDynamicAttrToPrim(const PrimitivePtr &prim) { void AnfRuntimeAlgorithm::SetDynamicAttrToPrim(const PrimitivePtr &prim) {
prim->AddAttr(kAttrMutableKernel, MakeValue(true)); prim->AddAttr(kAttrMutableKernel, MakeValue(true));
prim->AddAttr(kAttrInputIsDynamicShape, MakeValue(true)); prim->AddAttr(kAttrInputIsDynamicShape, MakeValue(true));
prim->AddAttr(kAttrOutputIsDynamicShape, MakeValue(true)); prim->AddAttr(kAttrOutputIsDynamicShape, MakeValue(true));

View File

@ -225,6 +225,10 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
static size_t GetInputElementNum(const AnfNodePtr &node); static size_t GetInputElementNum(const AnfNodePtr &node);
static bool IsRealSquenceOutput(const AnfNodePtr &node); static bool IsRealSquenceOutput(const AnfNodePtr &node);
static void SetDynamicAttrToPrim(const PrimitivePtr &prim); static void SetDynamicAttrToPrim(const PrimitivePtr &prim);
// Get output detail shape. These interfaces should take TUPLE output into consideration.
static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
}; };
} // namespace session } // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm; using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -139,9 +139,7 @@ class COMMON_EXPORT AnfAlgo {
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes, static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
AnfNode *node, bool disable_dynamic_len = false); AnfNode *node, bool disable_dynamic_len = false);
static void SetScalarTupleOutputInferType(const std::vector<TypeId> &types, const AnfNodePtr &node); static void SetScalarTupleOutputInferType(const std::vector<TypeId> &types, const AnfNodePtr &node);
// get and set output shape ptr // set output shape ptr
static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types, static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node); const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
@ -296,6 +294,9 @@ class COMMON_EXPORT AnfAlgo {
static bool HasTupleInput(const CNodePtr &node); static bool HasTupleInput(const CNodePtr &node);
static bool HasDynamicTupleInput(const CNodePtr &node); static bool HasDynamicTupleInput(const CNodePtr &node);
static bool IsReduceOp(const std::string &op_name); static bool IsReduceOp(const std::string &op_name);
// Get the element shape of dynamic sequence shape.
static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
}; };
} // namespace common } // namespace common
} // namespace mindspore } // namespace mindspore

View File

@ -1416,8 +1416,8 @@ class AscendAutoMonadConverter {
std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
source}; source};
auto cast_node = kernel_graph_->NewCNode(cast_inputs); auto cast_node = kernel_graph_->NewCNode(cast_inputs);
auto origin_shape = common::AnfAlgo::GetOutputDetailShape(source, kFirstOutput); auto origin_shape = AnfAlgo::GetOutputDetailShape(source, kFirstOutput);
auto shape = common::AnfAlgo::GetOutputDetailShape(target, kFirstOutput); auto shape = AnfAlgo::GetOutputDetailShape(target, kFirstOutput);
if (!common::IsEqual(origin_shape, shape)) { if (!common::IsEqual(origin_shape, shape)) {
MS_LOG(EXCEPTION) << "Assign: " << target->DebugString() << " and " << source->DebugString() MS_LOG(EXCEPTION) << "Assign: " << target->DebugString() << " and " << source->DebugString()
<< " has different shape, source shape: " << origin_shape->ToString() << " has different shape, source shape: " << origin_shape->ToString()

View File

@ -97,7 +97,7 @@ std::vector<int64_t> TbeJsonUtils::GetInputDeviceShapeForTbeBuild(const AnfNodeP
std::vector<int64_t> TbeJsonUtils::GetOutputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) { std::vector<int64_t> TbeJsonUtils::GetOutputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
std::vector<int64_t> shape; std::vector<int64_t> shape;
auto out_shape = common::AnfAlgo::GetOutputDetailShape(anf_node, real_idx); auto out_shape = AnfAlgo::GetOutputDetailShape(anf_node, real_idx);
MS_EXCEPTION_IF_NULL(out_shape); MS_EXCEPTION_IF_NULL(out_shape);
if (out_shape->isa<abstract::Shape>()) { if (out_shape->isa<abstract::Shape>()) {
auto shape_ptr = out_shape->cast<abstract::ShapePtr>(); auto shape_ptr = out_shape->cast<abstract::ShapePtr>();

View File

@ -71,8 +71,8 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output, std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
const std::string &format) { const std::string &format) {
auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index) auto shape =
: common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index); is_output ? AnfAlgo::GetOutputDetailShape(node, index) : AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
std::vector<int64_t> infer_shape; std::vector<int64_t> infer_shape;
if (shape->isa<abstract::Shape>()) { if (shape->isa<abstract::Shape>()) {
auto shape_ptr = shape->cast<abstract::ShapePtr>(); auto shape_ptr = shape->cast<abstract::ShapePtr>();

View File

@ -238,7 +238,7 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
<< input_format << " and dst format " << dst_format; << input_format << " and dst format " << dst_format;
} }
std::string spec_format = input_format == kOpFormat_DEFAULT ? dst_format : input_format; std::string spec_format = input_format == kOpFormat_DEFAULT ? dst_format : input_format;
auto input_node_out_shape = common::AnfAlgo::GetOutputDetailShape(input_node, 0); auto input_node_out_shape = AnfAlgo::GetOutputDetailShape(input_node, 0);
MS_EXCEPTION_IF_NULL(input_node_out_shape); MS_EXCEPTION_IF_NULL(input_node_out_shape);
auto out_shape_ptr = input_node_out_shape->cast<abstract::ShapePtr>(); auto out_shape_ptr = input_node_out_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(out_shape_ptr); MS_EXCEPTION_IF_NULL(out_shape_ptr);
@ -365,7 +365,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
MS_EXCEPTION_IF_NULL(trans_node); MS_EXCEPTION_IF_NULL(trans_node);
auto infer_type = common::AnfAlgo::GetOutputInferDataType(input, 0); auto infer_type = common::AnfAlgo::GetOutputInferDataType(input, 0);
auto out_shape_base = common::AnfAlgo::GetOutputDetailShape(input, 0); auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
MS_EXCEPTION_IF_NULL(out_shape_base); MS_EXCEPTION_IF_NULL(out_shape_base);
ShapeVector out_shape; ShapeVector out_shape;
bool is_dynamic_shape = false; bool is_dynamic_shape = false;
@ -555,8 +555,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
origin_type = common::AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); origin_type = common::AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
} }
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
const abstract::BaseShapePtr origin_shape = const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
// In graph kernel, we check parameter, // In graph kernel, we check parameter,
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast. // the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) { if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {

View File

@ -196,7 +196,7 @@ AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::K
MS_EXCEPTION_IF_NULL(tuple_item); MS_EXCEPTION_IF_NULL(tuple_item);
common::AnfAlgo::SetOutputTypeAndDetailShape( common::AnfAlgo::SetOutputTypeAndDetailShape(
{common::AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, {common::AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)},
{common::AnfAlgo::GetOutputDetailShape(buffer_fusion_kernel, output_index)}, tuple_item.get()); {AnfAlgo::GetOutputDetailShape(buffer_fusion_kernel, output_index)}, tuple_item.get());
return tuple_item; return tuple_item;
} }
@ -582,7 +582,7 @@ bool UbPatternFusion::ReplaceFusionOp(mindspore::HashMap<int64_t, BufferFusionIn
size_t out_num = AnfAlgo::GetOutputTensorNum(out_node); size_t out_num = AnfAlgo::GetOutputTensorNum(out_node);
for (size_t idx = 0; idx < out_num; ++idx) { for (size_t idx = 0; idx < out_num; ++idx) {
(void)types.emplace_back(common::AnfAlgo::GetOutputInferDataType(out_node, idx)); (void)types.emplace_back(common::AnfAlgo::GetOutputInferDataType(out_node, idx));
(void)shapes.emplace_back(common::AnfAlgo::GetOutputDetailShape(out_node, idx)); (void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(out_node, idx));
} }
} }
if (types.empty() || shapes.empty()) { if (types.empty() || shapes.empty()) {

View File

@ -231,7 +231,7 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
for (size_t i = 0; i < output_num; ++i) { for (size_t i = 0; i < output_num; ++i) {
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, i); auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, i);
common::AnfAlgo::SetOutputTypeAndDetailShape({std::get<0>(output_info)[i]}, common::AnfAlgo::SetOutputTypeAndDetailShape({std::get<0>(output_info)[i]},
{common::AnfAlgo::GetOutputDetailShape(node, i)}, tuple_getitem.get()); {AnfAlgo::GetOutputDetailShape(node, i)}, tuple_getitem.get());
(void)new_outputs.emplace_back(std::move(tuple_getitem)); (void)new_outputs.emplace_back(std::move(tuple_getitem));
} }
return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size); return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size);

View File

@ -64,7 +64,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
for (size_t input_idx = 0; input_idx < input_num; input_idx++) { for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_idx); auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_idx);
auto origin_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); auto origin_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
auto origin_shape_base_ptr = common::AnfAlgo::GetPrevNodeOutputDetailShape(cnode, input_idx); auto origin_shape_base_ptr = AnfAlgo::GetPrevNodeOutputDetailShape(cnode, input_idx);
auto origin_shape_ptr = origin_shape_base_ptr->cast<abstract::ShapePtr>(); auto origin_shape_ptr = origin_shape_base_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(origin_shape_ptr); MS_EXCEPTION_IF_NULL(origin_shape_ptr);
auto origin_shape = origin_shape_ptr->shape(); auto origin_shape = origin_shape_ptr->shape();

View File

@ -54,7 +54,7 @@ CNodePtr InsertForInput(const FuncGraphPtr &func_graph, const CNodePtr &node, co
auto in_node = common::AnfAlgo::GetInputNode(node, 0); auto in_node = common::AnfAlgo::GetInputNode(node, 0);
auto type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); auto type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
auto in_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(node, 0); auto in_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, 0);
auto transpose_out_shape = InferTransposeOutputShape(in_shape, perm); auto transpose_out_shape = InferTransposeOutputShape(in_shape, perm);
auto ori_out_types = AnfAlgo::GetAllOutputInferDataTypes(node); auto ori_out_types = AnfAlgo::GetAllOutputInferDataTypes(node);
@ -112,7 +112,7 @@ AnfNodePtr InsertForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_
(void)transpose_inputs.push_back(tuple_getitem); (void)transpose_inputs.push_back(tuple_getitem);
(void)transpose_inputs.push_back(perm_value_input); (void)transpose_inputs.push_back(perm_value_input);
auto shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx); auto shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
auto type = common::AnfAlgo::GetOutputInferDataType(node, output_idx); auto type = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
auto transpose_out_shape = InferTransposeOutputShape(shape, perm); auto transpose_out_shape = InferTransposeOutputShape(shape, perm);

View File

@ -113,7 +113,7 @@ AnfNodePtr DealRefOutput::AddAdditionalToRefOutput(const FuncGraphPtr &func_grap
auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
auto cur_shape = common::AnfAlgo::GetOutputInferShape(cnode, output_index); auto cur_shape = common::AnfAlgo::GetOutputInferShape(cnode, output_index);
auto detail_shape = common::AnfAlgo::GetOutputDetailShape(cnode, output_index); auto detail_shape = AnfAlgo::GetOutputDetailShape(cnode, output_index);
// insert trans // insert trans
if (origin_format != cur_format && cur_shape.size() > 1) { if (origin_format != cur_format && cur_shape.size() > 1) {
auto kernel_select = std::make_shared<KernelSelect>(); auto kernel_select = std::make_shared<KernelSelect>();

View File

@ -35,7 +35,7 @@ bool IsDepthwiseCase(const AnfNodePtr &node, size_t index, const std::string &fo
if (format != kOpFormat_FRAC_Z) { if (format != kOpFormat_FRAC_Z) {
return false; return false;
} }
abstract::BaseShapePtr base_shape = common::AnfAlgo::GetOutputDetailShape(node, index); abstract::BaseShapePtr base_shape = AnfAlgo::GetOutputDetailShape(node, index);
MS_EXCEPTION_IF_NULL(base_shape); MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>()) { if (base_shape->isa<abstract::Shape>()) {
auto shape_ptr = base_shape->cast<abstract::ShapePtr>(); auto shape_ptr = base_shape->cast<abstract::ShapePtr>();

View File

@ -51,7 +51,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
size_t out_num = AnfAlgo::GetOutputTensorNum(cnode); size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
AnfNodePtr replace_node = nullptr; AnfNodePtr replace_node = nullptr;
const auto origin_shape = common::AnfAlgo::GetOutputDetailShape(cnode, output_idx); const auto origin_shape = AnfAlgo::GetOutputDetailShape(cnode, output_idx);
const auto origin_type = common::AnfAlgo::GetOutputInferDataType(cnode, output_idx); const auto origin_type = common::AnfAlgo::GetOutputInferDataType(cnode, output_idx);
auto idx = NewValueNode(SizeToLong(output_idx)); auto idx = NewValueNode(SizeToLong(output_idx));
MS_EXCEPTION_IF_NULL(idx); MS_EXCEPTION_IF_NULL(idx);
@ -105,7 +105,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &o
// Single output, output is not TUPLE // Single output, output is not TUPLE
if (!cnode->Type()->isa<Tuple>()) { if (!cnode->Type()->isa<Tuple>()) {
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
const abstract::BaseShapePtr origin_shape = common::AnfAlgo::GetOutputDetailShape(cnode, 0); const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(cnode, 0);
const TypeId origin_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0); const TypeId origin_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
AnfNodePtr replace_node = cnode; AnfNodePtr replace_node = cnode;

View File

@ -31,8 +31,8 @@ bool IsDepthwiseCase(const CNodePtr &node, size_t index, const std::string &form
if (format != kOpFormat_FRAC_Z) { if (format != kOpFormat_FRAC_Z) {
return false; return false;
} }
abstract::BaseShapePtr base_shape = is_tuple ? common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index) abstract::BaseShapePtr base_shape =
: common::AnfAlgo::GetOutputDetailShape(node, index); is_tuple ? AnfAlgo::GetPrevNodeOutputDetailShape(node, index) : AnfAlgo::GetOutputDetailShape(node, index);
MS_EXCEPTION_IF_NULL(base_shape); MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>()) { if (base_shape->isa<abstract::Shape>()) {
auto shape_ptr = base_shape->cast<abstract::ShapePtr>(); auto shape_ptr = base_shape->cast<abstract::ShapePtr>();

View File

@ -125,7 +125,7 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cast); MS_EXCEPTION_IF_NULL(cast);
auto cast_dtype = common::AnfAlgo::GetOutputInferDataType(cast, 0); auto cast_dtype = common::AnfAlgo::GetOutputInferDataType(cast, 0);
auto cast_shape = common::AnfAlgo::GetOutputDetailShape(cast, 0); auto cast_shape = AnfAlgo::GetOutputDetailShape(cast, 0);
std::vector<abstract::BaseShapePtr> shapes; std::vector<abstract::BaseShapePtr> shapes;
std::vector<TypeId> types; std::vector<TypeId> types;
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
@ -135,7 +135,7 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
(void)types.emplace_back(cast_dtype); (void)types.emplace_back(cast_dtype);
continue; continue;
} }
(void)shapes.emplace_back(common::AnfAlgo::GetOutputDetailShape(cnode, index)); (void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(cnode, index));
(void)types.emplace_back(common::AnfAlgo::GetOutputInferDataType(cnode, index)); (void)types.emplace_back(common::AnfAlgo::GetOutputInferDataType(cnode, index));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());

View File

@ -20,6 +20,7 @@
#include "include/common/utils/anfalgo.h" #include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/core_ops.h" #include "mindspore/core/ops/core_ops.h"
#include "backend/common/optimizer/optimizer.h" #include "backend/common/optimizer/optimizer.h"
#include "backend/common/session/anf_runtime_algorithm.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -47,8 +48,7 @@ AnfNodePtr CreateCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input, co
if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) { if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input}); AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
MS_EXCEPTION_IF_NULL(cast); MS_EXCEPTION_IF_NULL(cast);
common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {common::AnfAlgo::GetOutputDetailShape(input, 0)}, common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get());
cast.get());
common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(dst_type), cast); common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(dst_type), cast);
cast->set_scope(input->scope()); cast->set_scope(input->scope());
return cast; return cast;

View File

@ -246,7 +246,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
tile_node->set_scope(mul_node->scope()); tile_node->set_scope(mul_node->scope());
common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)},
{common::AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1)}, {AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1)},
tile_node.get()); tile_node.get());
// Feature map set // Feature map set
std::vector<size_t> feature_map_input_indexs; std::vector<size_t> feature_map_input_indexs;

View File

@ -44,8 +44,7 @@ void BatchNormGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, co
auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 1),
common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 1), auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 1), AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_update_grad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_update_grad.get());
common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
@ -79,7 +78,7 @@ void BatchNormGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, co
bn_reduce_grad->set_scope(bn_grad_node->scope()); bn_reduce_grad->set_scope(bn_grad_node->scope());
auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_reduce_grad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_reduce_grad.get());
common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);

View File

@ -41,7 +41,7 @@ AnfNodePtr BCEWithLogitsLossFission::AddReduceNode(const FuncGraphPtr &func_grap
MS_EXCEPTION_IF_NULL(new_cnode); MS_EXCEPTION_IF_NULL(new_cnode);
auto predict_input = cnode->inputs()[kIndex1]; auto predict_input = cnode->inputs()[kIndex1];
auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)}; auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)};
auto new_node_shape = {common::AnfAlgo::GetOutputDetailShape(predict_input, 0)}; auto new_node_shape = {AnfAlgo::GetOutputDetailShape(predict_input, 0)};
// The kAttrReduction is necessary for InferShape of BCEWithLogitsLoss op // The kAttrReduction is necessary for InferShape of BCEWithLogitsLoss op
common::AnfAlgo::SetNodeAttr(kAttrReduction, MakeValue("none"), new_cnode); common::AnfAlgo::SetNodeAttr(kAttrReduction, MakeValue("none"), new_cnode);
common::AnfAlgo::SetOutputTypeAndDetailShape(new_node_dtype, new_node_shape, new_cnode.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(new_node_dtype, new_node_shape, new_cnode.get());
@ -60,7 +60,7 @@ AnfNodePtr BCEWithLogitsLossFission::AddReduceNode(const FuncGraphPtr &func_grap
} }
auto reduce_node = NewCNode(reduce_inputs, func_graph); auto reduce_node = NewCNode(reduce_inputs, func_graph);
MS_EXCEPTION_IF_NULL(reduce_node); MS_EXCEPTION_IF_NULL(reduce_node);
auto shape = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shape = {AnfAlgo::GetOutputDetailShape(node, 0)};
auto type = common::AnfAlgo::GetOutputInferDataType(node, 0); auto type = common::AnfAlgo::GetOutputInferDataType(node, 0);
if (type == kNumberTypeFloat16) { if (type == kNumberTypeFloat16) {
common::AnfAlgo::SetOutputTypeAndDetailShape({kNumberTypeFloat32}, shape, reduce_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape({kNumberTypeFloat32}, shape, reduce_node.get());

View File

@ -45,8 +45,7 @@ void BnGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNo
auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 1),
common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 1), auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 1), AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_update_grad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_update_grad.get());
common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
if (common::AnfAlgo::HasNodeAttr(kAttrFormat, bn_grad_node)) { if (common::AnfAlgo::HasNodeAttr(kAttrFormat, bn_grad_node)) {
@ -86,7 +85,7 @@ void BnGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNo
bn_reduce_grad->set_scope(bn_grad_node->scope()); bn_reduce_grad->set_scope(bn_grad_node->scope());
auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_reduce_grad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_reduce_grad.get());
common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
if (common::AnfAlgo::HasNodeAttr(kAttrFormat, bn_grad_node)) { if (common::AnfAlgo::HasNodeAttr(kAttrFormat, bn_grad_node)) {

View File

@ -56,8 +56,7 @@ bool BnSplit::CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const C
bn_training_reduce->set_kernel_info(kernel_info); bn_training_reduce->set_kernel_info(kernel_info);
auto types = {common::AnfAlgo::GetOutputInferDataType(bn_cnode, 1), auto types = {common::AnfAlgo::GetOutputInferDataType(bn_cnode, 1),
common::AnfAlgo::GetOutputInferDataType(bn_cnode, 1)}; common::AnfAlgo::GetOutputInferDataType(bn_cnode, 1)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_cnode, 1), auto shapes = {AnfAlgo::GetOutputDetailShape(bn_cnode, 1), AnfAlgo::GetOutputDetailShape(bn_cnode, 1)};
common::AnfAlgo::GetOutputDetailShape(bn_cnode, 1)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_training_reduce.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_training_reduce.get());
bn_training_reduce->set_scope(bn_cnode->scope()); bn_training_reduce->set_scope(bn_cnode->scope());
if (is_dynamic) { if (is_dynamic) {
@ -205,8 +204,7 @@ AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) { if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input}); AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {common::AnfAlgo::GetOutputDetailShape(input, 0)}, common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get());
cast.get());
common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
cast->set_scope(input->scope()); cast->set_scope(input->scope());
return cast; return cast;

View File

@ -50,7 +50,7 @@ AnfNodePtr ConcatFission::CreateNewConcat(const FuncGraphPtr &func_graph, const
if (axis_from_attr < 0) { if (axis_from_attr < 0) {
axis_from_attr += SizeToLong(input_shape.size()); axis_from_attr += SizeToLong(input_shape.size());
} }
auto output_shape_ptr = common::AnfAlgo::GetOutputDetailShape(origin_concat_cnode, 0); auto output_shape_ptr = AnfAlgo::GetOutputDetailShape(origin_concat_cnode, 0);
MS_EXCEPTION_IF_NULL(output_shape_ptr); MS_EXCEPTION_IF_NULL(output_shape_ptr);
auto output_shapeptr = output_shape_ptr->cast<abstract::ShapePtr>(); auto output_shapeptr = output_shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(output_shapeptr); MS_EXCEPTION_IF_NULL(output_shapeptr);
@ -63,7 +63,7 @@ AnfNodePtr ConcatFission::CreateNewConcat(const FuncGraphPtr &func_graph, const
auto axis = LongToSize(axis_from_attr); auto axis = LongToSize(axis_from_attr);
output_shape[axis] = 0; output_shape[axis] = 0;
for (size_t i = begin_index; i < begin_index + offset; ++i) { for (size_t i = begin_index; i < begin_index + offset; ++i) {
auto last_input_shape_ptr = common::AnfAlgo::GetPrevNodeOutputDetailShape(origin_concat_cnode, i - 1); auto last_input_shape_ptr = AnfAlgo::GetPrevNodeOutputDetailShape(origin_concat_cnode, i - 1);
MS_EXCEPTION_IF_NULL(last_input_shape_ptr); MS_EXCEPTION_IF_NULL(last_input_shape_ptr);
auto last_input_shapeptr = last_input_shape_ptr->cast<abstract::ShapePtr>(); auto last_input_shapeptr = last_input_shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(last_input_shapeptr); MS_EXCEPTION_IF_NULL(last_input_shapeptr);

View File

@ -20,6 +20,7 @@
#include "include/common/utils/anfalgo.h" #include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/core_ops.h" #include "mindspore/core/ops/core_ops.h"
#include "backend/common/optimizer/optimizer.h" #include "backend/common/optimizer/optimizer.h"
#include "backend/common/session/anf_runtime_algorithm.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -46,8 +47,7 @@ AnfNodePtr CreateCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input, co
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) { if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input}); AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {common::AnfAlgo::GetOutputDetailShape(input, 0)}, common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get());
cast.get());
common::AnfAlgo::SetNodeAttr(kAttrDstType, MakeValue(static_cast<size_t>(dst_type)), cast); common::AnfAlgo::SetNodeAttr(kAttrDstType, MakeValue(static_cast<size_t>(dst_type)), cast);
cast->set_scope(input->scope()); cast->set_scope(input->scope());
return cast; return cast;

View File

@ -48,8 +48,8 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackpropV2(const FuncGraphPtr
MS_EXCEPTION_IF_NULL(layer_norm_x_backprop); MS_EXCEPTION_IF_NULL(layer_norm_x_backprop);
layer_norm_x_backprop->set_scope(layer_norm_grad->scope()); layer_norm_x_backprop->set_scope(layer_norm_grad->scope());
auto types = {common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0), kNumberTypeFloat32}; auto types = {common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0), kNumberTypeFloat32};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(layer_norm_grad, 0), auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, 0),
common::AnfAlgo::GetPrevNodeOutputDetailShape(layer_norm_grad, 1)}; AnfAlgo::GetPrevNodeOutputDetailShape(layer_norm_grad, 1)};
if (is_dynamic) { if (is_dynamic) {
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), layer_norm_x_backprop); common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), layer_norm_x_backprop);
common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), layer_norm_x_backprop); common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), layer_norm_x_backprop);
@ -78,8 +78,8 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackpropV2(
} }
auto types = {common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputGammaIndex), auto types = {common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputGammaIndex),
common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputBetaIndex)}; common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputGammaIndex), auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputGammaIndex),
common::AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputBetaIndex)}; AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, layer_norm_beta_gamma_backprop.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, layer_norm_beta_gamma_backprop.get());
// get device shape of LayerNormGrad's 5th Input, and convert it to attr // get device shape of LayerNormGrad's 5th Input, and convert it to attr

View File

@ -40,7 +40,7 @@ AnfNodePtr PackFission::CreateNewPack(const FuncGraphPtr &func_graph, const CNod
std::vector<int64_t> dyn_input_sizes{SizeToLong(offset)}; std::vector<int64_t> dyn_input_sizes{SizeToLong(offset)};
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack); common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack);
// infer shape // infer shape
auto output_shape_ptr = common::AnfAlgo::GetOutputDetailShape(origin_pack_cnode, 0); auto output_shape_ptr = AnfAlgo::GetOutputDetailShape(origin_pack_cnode, 0);
MS_EXCEPTION_IF_NULL(output_shape_ptr); MS_EXCEPTION_IF_NULL(output_shape_ptr);
auto output_shape = output_shape_ptr->cast<abstract::ShapePtr>(); auto output_shape = output_shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(output_shape); MS_EXCEPTION_IF_NULL(output_shape);

View File

@ -54,7 +54,7 @@ const AnfNodePtr ReduceSumFission::Process(const FuncGraphPtr &graph, const AnfN
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
auto prim = common::AnfAlgo::GetCNodePrimitive(cnode); auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
auto keep_dims = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims); auto keep_dims = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims);
auto out_shape = common::AnfAlgo::GetOutputDetailShape(cnode, 0); auto out_shape = AnfAlgo::GetOutputDetailShape(cnode, 0);
std::vector<int64_t> inp_axis; std::vector<int64_t> inp_axis;
auto axis_value = prim->GetAttr(kAttrAxis); auto axis_value = prim->GetAttr(kAttrAxis);
MS_EXCEPTION_IF_NULL(axis_value); MS_EXCEPTION_IF_NULL(axis_value);

View File

@ -32,9 +32,9 @@ bool IsDepthwiseCase(const CNodePtr &node, const std::string &input_format, cons
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
abstract::BaseShapePtr base_shape; abstract::BaseShapePtr base_shape;
if (input_format == kOpFormat_FRAC_Z && output_format == kOpFormat_DEFAULT) { if (input_format == kOpFormat_FRAC_Z && output_format == kOpFormat_DEFAULT) {
base_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(node, 0); base_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, 0);
} else if (input_format == kOpFormat_DEFAULT && output_format == kOpFormat_FRAC_Z) { } else if (input_format == kOpFormat_DEFAULT && output_format == kOpFormat_FRAC_Z) {
base_shape = common::AnfAlgo::GetOutputDetailShape(node, 0); base_shape = AnfAlgo::GetOutputDetailShape(node, 0);
} else { } else {
return false; return false;
} }

View File

@ -330,8 +330,8 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL(add1); MS_EXCEPTION_IF_NULL(add1);
auto types = {common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0), auto types = {common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0),
common::AnfAlgo::GetOutputInferDataType(sub0, 0)}; common::AnfAlgo::GetOutputInferDataType(sub0, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(add1, 0), common::AnfAlgo::GetOutputDetailShape(add0, 0), auto shapes = {AnfAlgo::GetOutputDetailShape(add1, 0), AnfAlgo::GetOutputDetailShape(add0, 0),
common::AnfAlgo::GetOutputDetailShape(sub0, 0)}; AnfAlgo::GetOutputDetailShape(sub0, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
std::vector<AnfNodePtr> fusion_node_outputs; std::vector<AnfNodePtr> fusion_node_outputs;

View File

@ -141,8 +141,8 @@ const AnfNodePtr AdaptiveMaxPool2DFusion::Process(const FuncGraphPtr &func_graph
if (height % output_h != 0 || width % output_w != 0) { if (height % output_h != 0 || width % output_w != 0) {
auto types = {common::AnfAlgo::GetOutputInferDataType(adaptive_max_pool2d, 0), kNumberTypeInt64}; auto types = {common::AnfAlgo::GetOutputInferDataType(adaptive_max_pool2d, 0), kNumberTypeInt64};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0), auto shapes = {AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0),
common::AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)}; AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adaptive_max_pool2d.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adaptive_max_pool2d.get());
std::vector<AnfNodePtr> multi_outputs; std::vector<AnfNodePtr> multi_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, adaptive_max_pool2d, kAdaptiveMaxpool2DOutputNumber, &multi_outputs); CreateMultipleOutputsOfAnfNode(func_graph, adaptive_max_pool2d, kAdaptiveMaxpool2DOutputNumber, &multi_outputs);
@ -159,7 +159,7 @@ const AnfNodePtr AdaptiveMaxPool2DFusion::Process(const FuncGraphPtr &func_graph
adaptive_max_pool2d->inputs().end()); adaptive_max_pool2d->inputs().end());
auto pooling = NewCNode(pooling_inputs, kernel_graph); auto pooling = NewCNode(pooling_inputs, kernel_graph);
auto types = {common::AnfAlgo::GetOutputInferDataType(adaptive_max_pool2d, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(adaptive_max_pool2d, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, pooling.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, pooling.get());
pooling->set_scope(adaptive_max_pool2d->scope()); pooling->set_scope(adaptive_max_pool2d->scope());
SetNodeAttr(pooling, height_attr, width_attr); SetNodeAttr(pooling, height_attr, width_attr);

View File

@ -95,8 +95,7 @@ const AnfNodePtr BNReduceGradConv2dBackpropFilterFusion::Process(const FuncGraph
MS_EXCEPTION_IF_NULL(fused_dbn_dw); MS_EXCEPTION_IF_NULL(fused_dbn_dw);
auto types = {common::AnfAlgo::GetOutputInferDataType(bnreduce_grad, 0), auto types = {common::AnfAlgo::GetOutputInferDataType(bnreduce_grad, 0),
common::AnfAlgo::GetOutputInferDataType(conv_back_filter, 0)}; common::AnfAlgo::GetOutputInferDataType(conv_back_filter, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bnreduce_grad, 0), auto shapes = {AnfAlgo::GetOutputDetailShape(bnreduce_grad, 0), AnfAlgo::GetOutputDetailShape(conv_back_filter, 0)};
common::AnfAlgo::GetOutputDetailShape(conv_back_filter, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_dbn_dw.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_dbn_dw.get());
fused_dbn_dw->set_scope(bnreduce_grad->scope()); fused_dbn_dw->set_scope(bnreduce_grad->scope());
common::AnfAlgo::CopyNodeAttr(kAttrFilterSizes, conv_back_filter, fused_dbn_dw); common::AnfAlgo::CopyNodeAttr(kAttrFilterSizes, conv_back_filter, fused_dbn_dw);

View File

@ -65,7 +65,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
auto fusion_node = NewCNode(inputs, graph); auto fusion_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(fusion_node); MS_EXCEPTION_IF_NULL(fusion_node);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
fusion_node->set_scope(node->scope()); fusion_node->set_scope(node->scope());
return fusion_node; return fusion_node;

View File

@ -91,7 +91,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
auto clip_by_value = NewCNode(inputs, graph); auto clip_by_value = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(clip_by_value); MS_EXCEPTION_IF_NULL(clip_by_value);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, clip_by_value.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, clip_by_value.get());
clip_by_value->set_scope(node->scope()); clip_by_value->set_scope(node->scope());
return clip_by_value; return clip_by_value;

View File

@ -113,7 +113,7 @@ CNodePtr ConfusionMulGradFusion::CreateFusionNode(const FuncGraphPtr &graph, con
common::AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); common::AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
auto types = {common::AnfAlgo::GetOutputInferDataType(mul0, 0), auto types = {common::AnfAlgo::GetOutputInferDataType(mul0, 0),
common::AnfAlgo::GetOutputInferDataType(reduce_sum, 0)}; common::AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(mul0, 0), common::AnfAlgo::GetOutputDetailShape(reduce_sum, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(mul0, 0), AnfAlgo::GetOutputDetailShape(reduce_sum, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
return fusion_node; return fusion_node;
} }

View File

@ -183,8 +183,8 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra
std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1); std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0), auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0),
common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add5, 0)}; common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add5, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0), common::AnfAlgo::GetOutputDetailShape(add0, 0), auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0), AnfAlgo::GetOutputDetailShape(add0, 0),
common::AnfAlgo::GetOutputDetailShape(add1, 0), common::AnfAlgo::GetOutputDetailShape(add5, 0)}; AnfAlgo::GetOutputDetailShape(add1, 0), AnfAlgo::GetOutputDetailShape(add5, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
std::vector<AnfNodePtr> fusion_node_outputs; std::vector<AnfNodePtr> fusion_node_outputs;

View File

@ -70,7 +70,7 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL(lamb_update_with_lr); MS_EXCEPTION_IF_NULL(lamb_update_with_lr);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, lamb_update_with_lr.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, lamb_update_with_lr.get());
lamb_update_with_lr->set_scope(node->scope()); lamb_update_with_lr->set_scope(node->scope());
return lamb_update_with_lr; return lamb_update_with_lr;

View File

@ -83,7 +83,7 @@ const AnfNodePtr SoftmaxDropoutDoMaskV3Fusion::Process(const FuncGraphPtr &graph
MS_EXCEPTION_IF_NULL(softmax_dropout); MS_EXCEPTION_IF_NULL(softmax_dropout);
auto types = {common::AnfAlgo::GetOutputInferDataType(softmax, 0), auto types = {common::AnfAlgo::GetOutputInferDataType(softmax, 0),
common::AnfAlgo::GetOutputInferDataType(dropout, 0)}; common::AnfAlgo::GetOutputInferDataType(dropout, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(softmax, 0), common::AnfAlgo::GetOutputDetailShape(dropout, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(softmax, 0), AnfAlgo::GetOutputDetailShape(dropout, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, softmax_dropout.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, softmax_dropout.get());
softmax_dropout->set_scope(softmax->scope()); softmax_dropout->set_scope(softmax->scope());
common::AnfAlgo::CopyNodeAttr(kAttrAxis, softmax, softmax_dropout); common::AnfAlgo::CopyNodeAttr(kAttrAxis, softmax, softmax_dropout);

View File

@ -60,7 +60,7 @@ CNodePtr SquareSumFusion::GenerateSquareSumV1(const FuncGraphPtr &graph, const C
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
square_sumv1->set_kernel_info(kernel_info); square_sumv1->set_kernel_info(kernel_info);
auto types = {common::AnfAlgo::GetOutputInferDataType(sum, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(sum, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(sum, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(sum, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, square_sumv1.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, square_sumv1.get());
square_sumv1->set_scope(sum->scope()); square_sumv1->set_scope(sum->scope());
common::AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1); common::AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1);
@ -82,7 +82,7 @@ CNodePtr SquareSumFusion::GenerateSquareSumV2(const FuncGraphPtr &graph, const C
auto square_sumv2 = NewCNode(square_sumv2_inputs, graph); auto square_sumv2 = NewCNode(square_sumv2_inputs, graph);
MS_EXCEPTION_IF_NULL(square_sumv2); MS_EXCEPTION_IF_NULL(square_sumv2);
auto types = {common::AnfAlgo::GetOutputInferDataType(sum, 0), common::AnfAlgo::GetOutputInferDataType(square, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(sum, 0), common::AnfAlgo::GetOutputInferDataType(square, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(sum, 0), common::AnfAlgo::GetOutputDetailShape(square, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(sum, 0), AnfAlgo::GetOutputDetailShape(square, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, square_sumv2.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, square_sumv2.get());
square_sumv2->set_scope(sum->scope()); square_sumv2->set_scope(sum->scope());
common::AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2); common::AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2);

View File

@ -22,6 +22,7 @@
#include "include/common/utils/comm_manager.h" #include "include/common/utils/comm_manager.h"
#include "backend/common/optimizer/helper.h" #include "backend/common/optimizer/helper.h"
#include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/ops_info/ops_utils.h"
#include "backend/common/session/anf_runtime_algorithm.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -117,7 +118,7 @@ CNodePtr AllToAllUnifyMindIR::CreateAllToAllvNode(const FuncGraphPtr &graph, con
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end()); (void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
auto all_to_all_v = NewCNode(all_to_all_v_input, graph); auto all_to_all_v = NewCNode(all_to_all_v_input, graph);
MS_EXCEPTION_IF_NULL(all_to_all_v); MS_EXCEPTION_IF_NULL(all_to_all_v);
auto single_shape = common::AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL); auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL);
auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0UL); auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0UL);
std::vector<TypeId> dtypes(split_count, single_type); std::vector<TypeId> dtypes(split_count, single_type);
std::vector<BaseShapePtr> shapes(split_count, single_shape); std::vector<BaseShapePtr> shapes(split_count, single_shape);

View File

@ -49,11 +49,10 @@ AnfNodePtr BuildBatchNormGrad(const PatternMap &m, const AnfNodePtr &new_node) {
common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2UL), common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2UL),
common::AnfAlgo::GetPrevNodeOutputInferDataType(bn_grad_node, 3UL), common::AnfAlgo::GetPrevNodeOutputInferDataType(bn_grad_node, 3UL),
common::AnfAlgo::GetPrevNodeOutputInferDataType(bn_grad_node, 4UL)}; common::AnfAlgo::GetPrevNodeOutputInferDataType(bn_grad_node, 4UL)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 0UL), auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 0UL), AnfAlgo::GetOutputDetailShape(bn_grad_node, 1UL),
common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 1UL), AnfAlgo::GetOutputDetailShape(bn_grad_node, 2UL),
common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 2UL), AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 3UL),
common::AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 3UL), AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 4UL)};
common::AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 4UL)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_bn_grad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_bn_grad.get());
common::AnfAlgo::CopyNodeAttrs(bn_grad_node, new_bn_grad); common::AnfAlgo::CopyNodeAttrs(bn_grad_node, new_bn_grad);
common::AnfAlgo::SetNodeAttr(kAttrUnifyIRPassed, MakeValue(true), new_bn_grad); common::AnfAlgo::SetNodeAttr(kAttrUnifyIRPassed, MakeValue(true), new_bn_grad);

View File

@ -64,7 +64,7 @@ CNodePtr MaxPool2MaxPoolWithArgmax::CreateMaxPoolWithArgmax(const FuncGraphPtr &
// MaxPoolWithArgmax's second output is argmax, whose datatype is uint16 and with same shape as first output // MaxPoolWithArgmax's second output is argmax, whose datatype is uint16 and with same shape as first output
TypeId argmax_dtype = kNumberTypeUInt16; TypeId argmax_dtype = kNumberTypeUInt16;
auto types = {common::AnfAlgo::GetOutputInferDataType(maxpool, 0UL), argmax_dtype}; auto types = {common::AnfAlgo::GetOutputInferDataType(maxpool, 0UL), argmax_dtype};
auto out_shape = common::AnfAlgo::GetOutputDetailShape(maxpool, 0UL); auto out_shape = AnfAlgo::GetOutputDetailShape(maxpool, 0UL);
std::vector<BaseShapePtr> shapes = {out_shape, out_shape}; std::vector<BaseShapePtr> shapes = {out_shape, out_shape};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, maxpool_argmax.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, maxpool_argmax.get());
return maxpool_argmax; return maxpool_argmax;

View File

@ -835,7 +835,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids); auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids);
auto centerx_dtype = common::AnfAlgo::GetOutputInferDataType(centerx, 0UL); auto centerx_dtype = common::AnfAlgo::GetOutputInferDataType(centerx, 0UL);
auto centerx_shape = common::AnfAlgo::GetOutputInferShape(centerx, 0UL); auto centerx_shape = common::AnfAlgo::GetOutputInferShape(centerx, 0UL);
auto base_shape = common::AnfAlgo::GetOutputDetailShape(centerx, 0UL); auto base_shape = AnfAlgo::GetOutputDetailShape(centerx, 0UL);
// empty // empty
int64_t all_to_all_output_num = int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; }); std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });

View File

@ -294,9 +294,9 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
auto tile_node = pass.NewCNode(tile_inputs, graph); auto tile_node = pass.NewCNode(tile_inputs, graph);
MS_EXCEPTION_IF_NULL(tile_node); MS_EXCEPTION_IF_NULL(tile_node);
tile_node->set_scope(mul_node->scope()); tile_node->set_scope(mul_node->scope());
common::AnfAlgo::SetOutputTypeAndDetailShape( common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1UL)},
{common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1UL)}, {AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1UL)},
{common::AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1UL)}, tile_node.get()); tile_node.get());
if (is_convert_const_to_attr) { if (is_convert_const_to_attr) {
common::AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node); common::AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
} }

View File

@ -134,8 +134,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_index); auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_index);
MS_EXCEPTION_IF_NULL(cur_input); MS_EXCEPTION_IF_NULL(cur_input);
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
const abstract::BaseShapePtr origin_shape = const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
if (origin_type != device_type && origin_type != kTypeUnknown && device_type != kTypeUnknown) { if (origin_type != device_type && origin_type != kTypeUnknown && device_type != kTypeUnknown) {
auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape); auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape);
@ -199,7 +198,7 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &
auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(func_output, i); auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(func_output, i);
const std::string dev_fmt = AnfAlgo::GetPrevNodeOutputFormat(func_output, i); const std::string dev_fmt = AnfAlgo::GetPrevNodeOutputFormat(func_output, i);
if (infer_type != device_type && device_type != kTypeUnknown) { if (infer_type != device_type && device_type != kTypeUnknown) {
const abstract::BaseShapePtr origin_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(func_output_node, i); const abstract::BaseShapePtr origin_shape = AnfAlgo::GetPrevNodeOutputDetailShape(func_output_node, i);
auto cast = AddCastOpNodeToGraph(func_graph, input_node, dev_fmt, device_type, infer_type, origin_shape); auto cast = AddCastOpNodeToGraph(func_graph, input_node, dev_fmt, device_type, infer_type, origin_shape);
MS_EXCEPTION_IF_NULL(cast); MS_EXCEPTION_IF_NULL(cast);
cast->set_scope(func_output->scope()); cast->set_scope(func_output->scope());

View File

@ -86,7 +86,7 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
auto transpose_op = graph->NewCNode(transpose_input); auto transpose_op = graph->NewCNode(transpose_input);
// 3.Set the output info of transpose. // 3.Set the output info of transpose.
auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, IntToSize(used_node_index))}; auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, IntToSize(used_node_index))};
auto transpose_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, IntToSize(used_node_index))}; auto transpose_shape = {AnfAlgo::GetPrevNodeOutputDetailShape(used_node, IntToSize(used_node_index))};
common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, transpose_shape, transpose_op.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, transpose_shape, transpose_op.get());
// 4. Set the new edge of transpose op. // 4. Set the new edge of transpose op.
FuncGraphManagerPtr manager = graph->manager(); FuncGraphManagerPtr manager = graph->manager();

View File

@ -131,7 +131,7 @@ bool PrintValueType::Run(const FuncGraphPtr &graph) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i)); types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i));
shapes.push_back(common::AnfAlgo::GetOutputDetailShape(cnode, i)); shapes.push_back(AnfAlgo::GetOutputDetailShape(cnode, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());
// add build info // add build info

View File

@ -165,7 +165,7 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
auto adam = graph->NewCNode(inputs); auto adam = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam); MS_EXCEPTION_IF_NULL(adam);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adam.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adam.get());
adam->set_scope(node->scope()); adam->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(adam); auto build_info = GenerateKernelBuildInfo(adam);

View File

@ -170,7 +170,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
auto adam_weight_decay = graph->NewCNode(inputs); auto adam_weight_decay = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam_weight_decay); MS_EXCEPTION_IF_NULL(adam_weight_decay);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adam_weight_decay.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adam_weight_decay.get());
adam_weight_decay->set_scope(node->scope()); adam_weight_decay->set_scope(node->scope());

View File

@ -89,7 +89,7 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
auto add_relugrad = graph->NewCNode(inputs); auto add_relugrad = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add_relugrad); MS_EXCEPTION_IF_NULL(add_relugrad);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, add_relugrad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, add_relugrad.get());
add_relugrad->set_scope(node->scope()); add_relugrad->set_scope(node->scope());

View File

@ -92,7 +92,7 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
size_t output_num = AnfAlgo::GetOutputElementNum(node); size_t output_num = AnfAlgo::GetOutputElementNum(node);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
types.push_back(common::AnfAlgo::GetOutputInferDataType(node, i)); types.push_back(common::AnfAlgo::GetOutputInferDataType(node, i));
shapes.push_back(common::AnfAlgo::GetOutputDetailShape(node, i)); shapes.push_back(AnfAlgo::GetOutputDetailShape(node, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, add_relu.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, add_relu.get());

View File

@ -100,7 +100,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
MS_EXCEPTION_IF_NULL(all_to_all_v); MS_EXCEPTION_IF_NULL(all_to_all_v);
// Prepare dtypes, shapes and ranks vectors. // Prepare dtypes, shapes and ranks vectors.
auto single_shape = common::AnfAlgo::GetOutputDetailShape(split_outputs[0], 0); auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0);
auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0); auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
std::vector<TypeId> dtypes(split_count, single_type); std::vector<TypeId> dtypes(split_count, single_type);
std::vector<BaseShapePtr> shapes(split_count, single_shape); std::vector<BaseShapePtr> shapes(split_count, single_shape);

View File

@ -89,7 +89,7 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co
auto replace_node = graph->NewCNode(inputs); auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node); MS_EXCEPTION_IF_NULL(replace_node);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get());
replace_node->set_scope(node->scope()); replace_node->set_scope(node->scope());
return replace_node; return replace_node;

View File

@ -61,7 +61,7 @@ const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &gra
auto replace_node = graph->NewCNode(inputs); auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node); MS_EXCEPTION_IF_NULL(replace_node);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get());
replace_node->set_scope(node->scope()); replace_node->set_scope(node->scope());
return replace_node; return replace_node;

View File

@ -127,7 +127,7 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
auto replace_node = graph->NewCNode(inputs); auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node); MS_EXCEPTION_IF_NULL(replace_node);
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get());
replace_node->set_scope(node->scope()); replace_node->set_scope(node->scope());
return replace_node; return replace_node;

View File

@ -149,7 +149,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm); auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(batch_norm, i)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(batch_norm, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu); common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);

View File

@ -74,11 +74,11 @@ void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad
auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad); auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
for (size_t i = 0; i < output_num; ++i) { for (size_t i = 0; i < output_num; ++i) {
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(bn_grad, i)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(bn_grad, i));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(bn_grad, i)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(bn_grad, i));
} }
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(relu_grad, 0)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(relu_grad, 0));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(relu_grad, 0)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(relu_grad, 0));
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, bn_add_relu_grad.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
} }

View File

@ -101,7 +101,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm); auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(batch_norm, i)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(batch_norm, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_relu.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_relu.get());
common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_relu); common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_relu);

View File

@ -101,7 +101,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
auto output_num = AnfAlgo::GetOutputTensorNum(node); auto output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, i)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, i));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(node, i)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(node, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
common::AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu); common::AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);

View File

@ -41,7 +41,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto predict_input = cnode->inputs()[1]; auto predict_input = cnode->inputs()[1];
auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)}; auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)};
auto new_node_shape = {common::AnfAlgo::GetOutputDetailShape(predict_input, 0)}; auto new_node_shape = {AnfAlgo::GetOutputDetailShape(predict_input, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(new_node_dtype, new_node_shape, new_cnode.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(new_node_dtype, new_node_shape, new_cnode.get());
// Add reduce node // Add reduce node
@ -69,7 +69,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
auto reduce_node = func_graph->NewCNode(reduce_inputs); auto reduce_node = func_graph->NewCNode(reduce_inputs);
MS_EXCEPTION_IF_NULL(reduce_node); MS_EXCEPTION_IF_NULL(reduce_node);
auto type = common::AnfAlgo::GetOutputInferDataType(node, 0); auto type = common::AnfAlgo::GetOutputInferDataType(node, 0);
auto shape = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shape = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape({type}, shape, reduce_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape({type}, shape, reduce_node.get());
common::AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node); common::AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
reduce_node->set_scope(cnode->scope()); reduce_node->set_scope(cnode->scope());

View File

@ -155,7 +155,7 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
idx->set_abstract(abstract_scalar); idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem); MS_EXCEPTION_IF_NULL(tuple_getitem);
auto shape = common::AnfAlgo::GetOutputDetailShape(node, i); auto shape = AnfAlgo::GetOutputDetailShape(node, i);
common::AnfAlgo::SetOutputTypeAndDetailShape({std::get<0>(output_info)[i]}, {shape}, tuple_getitem.get()); common::AnfAlgo::SetOutputTypeAndDetailShape({std::get<0>(output_info)[i]}, {shape}, tuple_getitem.get());
new_outputs.emplace_back(std::move(tuple_getitem)); new_outputs.emplace_back(std::move(tuple_getitem));
} }

View File

@ -158,7 +158,7 @@ void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) {
std::vector<BaseShapePtr> shapes; std::vector<BaseShapePtr> shapes;
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
types.emplace_back(common::AnfAlgo::GetOutputInferDataType(src, i)); types.emplace_back(common::AnfAlgo::GetOutputInferDataType(src, i));
shapes.emplace_back(common::AnfAlgo::GetOutputDetailShape(src, i)); shapes.emplace_back(AnfAlgo::GetOutputDetailShape(src, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, dst.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, dst.get());
} }

View File

@ -38,7 +38,7 @@ void InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i, con
auto cast = graph->NewCNode(inputs); auto cast = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cast); MS_EXCEPTION_IF_NULL(cast);
common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(cast_type), cast); common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(cast_type), cast);
auto cast_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(node, i)}; auto cast_shape = {AnfAlgo::GetPrevNodeOutputDetailShape(node, i)};
common::AnfAlgo::SetOutputTypeAndDetailShape({cast_type}, cast_shape, cast.get()); common::AnfAlgo::SetOutputTypeAndDetailShape({cast_type}, cast_shape, cast.get());
FuncGraphManagerPtr manager = graph->manager(); FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
@ -110,7 +110,7 @@ bool InsertCastGPU::Run(const FuncGraphPtr &graph) {
auto output_types = std::vector<TypeId>(output_num, kNumberTypeFloat32); auto output_types = std::vector<TypeId>(output_num, kNumberTypeFloat32);
std::vector<BaseShapePtr> output_shapes; std::vector<BaseShapePtr> output_shapes;
for (size_t output_index = 0; output_index < output_num; ++output_index) { for (size_t output_index = 0; output_index < output_num; ++output_index) {
auto shape = common::AnfAlgo::GetOutputDetailShape(node, output_index); auto shape = AnfAlgo::GetOutputDetailShape(node, output_index);
(void)output_shapes.emplace_back(shape); (void)output_shapes.emplace_back(shape);
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(output_types, output_shapes, node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(output_types, output_shapes, node.get());

View File

@ -130,7 +130,7 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
MS_EXCEPTION_IF_NULL(transpose_op); MS_EXCEPTION_IF_NULL(transpose_op);
// 3.Set the output info of transpose. // 3.Set the output info of transpose.
auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)}; auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto base_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, used_node_index); auto base_shape = AnfAlgo::GetPrevNodeOutputDetailShape(used_node, used_node_index);
common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, {base_shape}, transpose_op.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, {base_shape}, transpose_op.get());
// 4. Set the new edge of transpose op. // 4. Set the new edge of transpose op.

View File

@ -100,7 +100,7 @@ const AnfNodePtr MatMulBiasAddFusion::Process(const FuncGraphPtr &graph, const A
// Copy Abstract and KernelBuildInfo. // Copy Abstract and KernelBuildInfo.
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)}; auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_node.get());
common::AnfAlgo::CopyNodeAttrs(matmul, fused_node); common::AnfAlgo::CopyNodeAttrs(matmul, fused_node);
fused_node->set_scope(node->scope()); fused_node->set_scope(node->scope());

View File

@ -91,7 +91,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm); auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(batch_norm, i)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(batch_norm, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu); common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);

View File

@ -191,7 +191,7 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i)); types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i));
shapes.push_back(common::AnfAlgo::GetOutputDetailShape(cnode, i)); shapes.push_back(AnfAlgo::GetOutputDetailShape(cnode, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, print_fused.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, print_fused.get());
// add build info // add build info

View File

@ -83,7 +83,7 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
auto element_num = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>()); auto element_num = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
std::vector<int64_t> mask_shape = {(element_num + kBitPerUInt - 1) / kBitPerUInt}; std::vector<int64_t> mask_shape = {(element_num + kBitPerUInt - 1) / kBitPerUInt};
std::vector<BaseShapePtr> shapes = {common::AnfAlgo::GetOutputDetailShape(relu, 0), std::vector<BaseShapePtr> shapes = {AnfAlgo::GetOutputDetailShape(relu, 0),
std::make_shared<abstract::Shape>(mask_shape)}; std::make_shared<abstract::Shape>(mask_shape)};
auto types = {common::AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32}; auto types = {common::AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32};
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get());
@ -110,7 +110,7 @@ CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad,
size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad); size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
types.push_back(common::AnfAlgo::GetOutputInferDataType(relu_grad, i)); types.push_back(common::AnfAlgo::GetOutputInferDataType(relu_grad, i));
shapes.push_back(common::AnfAlgo::GetOutputDetailShape(relu_grad, i)); shapes.push_back(AnfAlgo::GetOutputDetailShape(relu_grad, i));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get());

View File

@ -44,7 +44,7 @@ AnfNodePtr BuildAdd(const PatternMap &m, const AnfNodePtr &default_node) {
std::vector<TypeId> outputs_type; std::vector<TypeId> outputs_type;
std::vector<BaseShapePtr> outputs_shape; std::vector<BaseShapePtr> outputs_shape;
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(m.Get(A), 0)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(m.Get(A), 0));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(m.Get(A), 0)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(m.Get(A), 0));
common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, default_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, default_node.get());
AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(m.Get(m_addn)), default_node.get()); AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(m.Get(m_addn)), default_node.get());
return default_node; return default_node;

View File

@ -51,7 +51,7 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
auto output_num = AnfAlgo::GetOutputTensorNum(node); auto output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, i)); outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, i));
outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(node, i)); outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(node, i));
} }
outputs_type[kGradIndex] = common::AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0); outputs_type[kGradIndex] = common::AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);

View File

@ -39,7 +39,7 @@ void CopyGraphOutputTypeAndShape(const std::vector<session::KernelWithIndex> &gr
std::vector<BaseShapePtr> shapes; std::vector<BaseShapePtr> shapes;
for (const auto &item : graph_outputs) { for (const auto &item : graph_outputs) {
types.push_back(common::AnfAlgo::GetOutputInferDataType(item.first, item.second)); types.push_back(common::AnfAlgo::GetOutputInferDataType(item.first, item.second));
shapes.push_back(common::AnfAlgo::GetOutputDetailShape(item.first, item.second)); shapes.push_back(AnfAlgo::GetOutputDetailShape(item.first, item.second));
} }
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, trt_node.get()); common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, trt_node.get());

View File

@ -625,30 +625,6 @@ inline ShapeVector GetShape(const abstract::BaseShapePtr &base_shape) {
return shape_ptr->shape(); return shape_ptr->shape();
} }
namespace {
// Get the element shape of dynamic sequence shape.
abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (node->Shape() == nullptr || (!node->Shape()->isa<abstract::DynamicSequenceShape>())) {
MS_LOG(EXCEPTION) << "Invalid dynamic shape in node:" << node->DebugString() << ".";
}
if (node->abstract() == nullptr) {
MS_LOG(EXCEPTION) << "Empty abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
if (!node->abstract()->isa<abstract::AbstractSequence>()) {
MS_LOG(EXCEPTION) << "Not sequence abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
const auto &sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(sequence_abs);
if (!sequence_abs->dynamic_len()) {
MS_LOG(EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
const auto &element_abs = sequence_abs->dynamic_len_element_abs();
MS_EXCEPTION_IF_NULL(element_abs);
return element_abs->BuildShape();
}
} // namespace
ShapeVector AnfAlgo::GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape, ShapeVector AnfAlgo::GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
size_t output_idx, bool is_real_squence_output) { size_t output_idx, bool is_real_squence_output) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -758,45 +734,6 @@ TypeId AnfAlgo::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t in
return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
} }
abstract::BaseShapePtr AnfAlgo::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
auto base_shape = node->Shape();
MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>()) {
if (output_idx == 0) {
return base_shape;
}
MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
<< "]." << trace::DumpSourceLines(node);
} else if (base_shape->isa<abstract::TupleShape>()) {
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);
if (output_idx >= tuple_shape->size()) {
MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
<< " node:" << node->DebugString() << "." << trace::DumpSourceLines(node);
}
auto b_shp = (*tuple_shape)[output_idx];
if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>()) {
return b_shp;
} else {
MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
<< "node :" << node->DebugString() << "." << trace::DumpSourceLines(node);
}
} else if (base_shape->isa<abstract::NoShape>()) {
return base_shape;
} else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
return GetDynamicSequenceShape(node, output_idx);
}
MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
<< base_shape->ToString() << " node : " << node->DebugString() << trace::DumpSourceLines(node);
}
abstract::BaseShapePtr AnfAlgo::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfAlgo::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
}
// set infer shapes and types of anf node // set infer shapes and types of anf node
void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types, void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) { const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) {
@ -1936,5 +1873,26 @@ bool AnfAlgo::IsReduceOp(const std::string &op_name) {
prim::kPrimReduceSum->name(), prim::kPrimSquareSumV1->name()}; prim::kPrimReduceSum->name(), prim::kPrimSquareSumV1->name()};
return reduce_op_type.find(op_name) != reduce_op_type.end(); return reduce_op_type.find(op_name) != reduce_op_type.end();
} }
abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (node->Shape() == nullptr || (!node->Shape()->isa<abstract::DynamicSequenceShape>())) {
MS_LOG(EXCEPTION) << "Invalid dynamic shape in node:" << node->DebugString() << ".";
}
if (node->abstract() == nullptr) {
MS_LOG(EXCEPTION) << "Empty abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
if (!node->abstract()->isa<abstract::AbstractSequence>()) {
MS_LOG(EXCEPTION) << "Not sequence abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
const auto &sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(sequence_abs);
if (!sequence_abs->dynamic_len()) {
MS_LOG(EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
const auto &element_abs = sequence_abs->dynamic_len_element_abs();
MS_EXCEPTION_IF_NULL(element_abs);
return element_abs->BuildShape();
}
} // namespace common } // namespace common
} // namespace mindspore } // namespace mindspore

View File

@ -151,6 +151,7 @@ def init(backend_name=None):
raise RuntimeError("Parameter server and scheduler should use 'CPU' as backend instead of 'Ascend'") raise RuntimeError("Parameter server and scheduler should use 'CPU' as backend instead of 'Ascend'")
if _get_ps_context("worker_num") == 1: if _get_ps_context("worker_num") == 1:
GlobalComm.INITED = True GlobalComm.INITED = True
_set_elegant_exit_handle()
return return
if device_target != "Ascend": if device_target != "Ascend":
raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, " raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "