forked from mindspore-Ecosystem/mindspore
Move get detail shape to backend anfalgo.
parent 13cfed8045
commit 6cfd711064
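This commit moves the detail-shape getters off the frontend/common helper class and onto the backend one: GetOutputDetailShape and GetPrevNodeOutputDetailShape are removed from common::AnfAlgo (include/common/utils/anfalgo.h) and re-declared on session::AnfRuntimeAlgorithm in backend/common/session/anf_runtime_algorithm.h, which already provides `using AnfAlgo = session::AnfRuntimeAlgorithm;`. Every backend call site therefore changes in the same mechanical way, dropping the common:: qualifier. A minimal before/after sketch (the `node` variable is hypothetical):

  // before: the getter lived on the common-side helper class
  abstract::BaseShapePtr shape = common::AnfAlgo::GetOutputDetailShape(node, 0);

  // after: the getter lives on the backend class; here `AnfAlgo` resolves to
  // session::AnfRuntimeAlgorithm through the alias in anf_runtime_algorithm.h
  abstract::BaseShapePtr shape = AnfAlgo::GetOutputDetailShape(node, 0);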
@@ -633,10 +633,10 @@ CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, cons
   if (is_input) {
     auto node_input = common::AnfAlgo::GetInputNode(node, 0);
     (void)new_cast_inputs.emplace_back(node_input);
-    shape = common::AnfAlgo::GetOutputDetailShape(node_input, 0);
+    shape = AnfAlgo::GetOutputDetailShape(node_input, 0);
   } else {
     (void)new_cast_inputs.emplace_back(node);
-    shape = common::AnfAlgo::GetOutputDetailShape(node, 0);
+    shape = AnfAlgo::GetOutputDetailShape(node, 0);
   }
   CNodePtr new_cast = NewCNode(new_cast_inputs, func_graph, {node});
   new_cast->set_scope(node->scope());
@@ -549,7 +549,7 @@ bool AnfRuntimeAlgorithm::IsRealSquenceOutput(const AnfNodePtr &node) {
 
 std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t output_idx,
                                                                           const std::string &format) {
-  auto output_shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx);
+  auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
   std::vector<int64_t> infer_shape;
   if (output_shape->isa<abstract::Shape>()) {
     auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
@@ -589,7 +589,7 @@ ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, si
 
 std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &node, size_t input_idx,
                                                                          const std::string &format) {
-  auto output_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(node, input_idx);
+  auto output_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, input_idx);
   std::vector<int64_t> infer_shape;
   if (output_shape->isa<abstract::Shape>()) {
     auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
@@ -1733,7 +1733,49 @@ std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputObjectType(const AnfNodePtr
   return {AnfAlgo::GetAbstractObjectType(node->abstract())};
 }
 
-std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
+abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto base_shape = node->Shape();
+  MS_EXCEPTION_IF_NULL(base_shape);
+  if (base_shape->isa<abstract::Shape>()) {
+    if (output_idx == 0) {
+      return base_shape;
+    }
+    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
+                      << "]." << trace::DumpSourceLines(node);
+  } else if (base_shape->isa<abstract::TupleShape>()) {
+    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_shape);
+    if (IsRealSquenceOutput(node)) {
+      return tuple_shape;
+    }
+    if (output_idx >= tuple_shape->size()) {
+      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
+                        << " node:" << node->DebugString() << "." << trace::DumpSourceLines(node);
+    }
+    auto b_shp = (*tuple_shape)[output_idx];
+    if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>()) {
+      return b_shp;
+    } else {
+      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
+                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
+                        << "node :" << node->DebugString() << "." << trace::DumpSourceLines(node);
+    }
+  } else if (base_shape->isa<abstract::NoShape>()) {
+    return base_shape;
+  } else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
+    return common::AnfAlgo::GetDynamicSequenceShape(node, output_idx);
+  }
+  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
+                    << base_shape->ToString() << " node : " << node->DebugString() << trace::DumpSourceLines(node);
+}
+
+abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
+  KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_idx);
+  return AnfAlgo::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
+}
+
+std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   std::vector<TypeId> outputs;
   auto out_nums = AnfAlgo::GetOutputElementNum(node);
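The body added above appears to be carried over from the common-side implementation, per the commit title: it dispatches on the kind of the node's abstract shape (a single Shape, a TupleShape with per-output element shapes, NoShape, or DynamicSequenceShape, which is still delegated to common::AnfAlgo::GetDynamicSequenceShape). A minimal caller sketch of the consumption pattern used throughout the call sites below (`node` is a hypothetical AnfNodePtr):

  auto base_shape = AnfAlgo::GetOutputDetailShape(node, 0);
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    // tensor output: extract the concrete dimension vector
    auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_ptr);
    ShapeVector dims = shape_ptr->shape();
  }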
@@ -1746,7 +1788,7 @@ std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node)
 
 // if input node is MakeTuple, find the PrevNodeNum recursively;
 // The monad node in the end is not included in the num;
-size_t AnfAlgo::GetInputElementNum(const AnfNodePtr &node) {
+size_t AnfRuntimeAlgorithm::GetInputElementNum(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
@@ -1772,7 +1814,7 @@ size_t AnfAlgo::GetInputElementNum(const AnfNodePtr &node) {
   return element_num;
 }
 
-void AnfAlgo::SetDynamicAttrToPrim(const PrimitivePtr &prim) {
+void AnfRuntimeAlgorithm::SetDynamicAttrToPrim(const PrimitivePtr &prim) {
   prim->AddAttr(kAttrMutableKernel, MakeValue(true));
   prim->AddAttr(kAttrInputIsDynamicShape, MakeValue(true));
   prim->AddAttr(kAttrOutputIsDynamicShape, MakeValue(true));
@@ -225,6 +225,10 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
   static size_t GetInputElementNum(const AnfNodePtr &node);
   static bool IsRealSquenceOutput(const AnfNodePtr &node);
   static void SetDynamicAttrToPrim(const PrimitivePtr &prim);
+
+  // Get output detail shape. These interfaces should take TUPLE output into consideration.
+  static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
+  static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
@@ -139,9 +139,7 @@ class COMMON_EXPORT AnfAlgo {
   static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
                                          AnfNode *node, bool disable_dynamic_len = false);
   static void SetScalarTupleOutputInferType(const std::vector<TypeId> &types, const AnfNodePtr &node);
-  // get and set output shape ptr
-  static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
-  static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
+  // set output shape ptr
   static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
                                           const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
   static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
@@ -296,6 +294,9 @@ class COMMON_EXPORT AnfAlgo {
   static bool HasTupleInput(const CNodePtr &node);
   static bool HasDynamicTupleInput(const CNodePtr &node);
   static bool IsReduceOp(const std::string &op_name);
+
+  // Get the element shape of dynamic sequence shape.
+  static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
 };
 }  // namespace common
 }  // namespace mindspore
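Net effect on the two headers: the detail-shape getters now exist only on the backend class, while common::AnfAlgo keeps the shape setters and gains a public GetDynamicSequenceShape, which the backend getter delegates to for dynamic sequences. Which class an unqualified `AnfAlgo::` call resolves to depends on which alias is in scope; a simplified sketch, under the assumption that the backend header is included:

  // backend translation units (anf_runtime_algorithm.h in scope):
  using AnfAlgo = session::AnfRuntimeAlgorithm;
  auto detail = AnfAlgo::GetOutputDetailShape(node, 0);           // backend-only getter
  auto dyn = common::AnfAlgo::GetDynamicSequenceShape(node, 0);   // still on the common class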
@@ -1416,8 +1416,8 @@ class AscendAutoMonadConverter {
     std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
                                            source};
     auto cast_node = kernel_graph_->NewCNode(cast_inputs);
-    auto origin_shape = common::AnfAlgo::GetOutputDetailShape(source, kFirstOutput);
-    auto shape = common::AnfAlgo::GetOutputDetailShape(target, kFirstOutput);
+    auto origin_shape = AnfAlgo::GetOutputDetailShape(source, kFirstOutput);
+    auto shape = AnfAlgo::GetOutputDetailShape(target, kFirstOutput);
     if (!common::IsEqual(origin_shape, shape)) {
       MS_LOG(EXCEPTION) << "Assign: " << target->DebugString() << " and " << source->DebugString()
                         << " has different shape, source shape: " << origin_shape->ToString()
@@ -97,7 +97,7 @@ std::vector<int64_t> TbeJsonUtils::GetInputDeviceShapeForTbeBuild(const AnfNodeP
 std::vector<int64_t> TbeJsonUtils::GetOutputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
   MS_EXCEPTION_IF_NULL(anf_node);
   std::vector<int64_t> shape;
-  auto out_shape = common::AnfAlgo::GetOutputDetailShape(anf_node, real_idx);
+  auto out_shape = AnfAlgo::GetOutputDetailShape(anf_node, real_idx);
   MS_EXCEPTION_IF_NULL(out_shape);
   if (out_shape->isa<abstract::Shape>()) {
     auto shape_ptr = out_shape->cast<abstract::ShapePtr>();
@@ -71,8 +71,8 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
 
 std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                    const std::string &format) {
-  auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
-                         : common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
+  auto shape =
+    is_output ? AnfAlgo::GetOutputDetailShape(node, index) : AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
   std::vector<int64_t> infer_shape;
   if (shape->isa<abstract::Shape>()) {
     auto shape_ptr = shape->cast<abstract::ShapePtr>();
@@ -238,7 +238,7 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
                       << input_format << " and dst format " << dst_format;
   }
   std::string spec_format = input_format == kOpFormat_DEFAULT ? dst_format : input_format;
-  auto input_node_out_shape = common::AnfAlgo::GetOutputDetailShape(input_node, 0);
+  auto input_node_out_shape = AnfAlgo::GetOutputDetailShape(input_node, 0);
   MS_EXCEPTION_IF_NULL(input_node_out_shape);
   auto out_shape_ptr = input_node_out_shape->cast<abstract::ShapePtr>();
   MS_EXCEPTION_IF_NULL(out_shape_ptr);
@@ -365,7 +365,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   MS_EXCEPTION_IF_NULL(trans_node);
   auto infer_type = common::AnfAlgo::GetOutputInferDataType(input, 0);
 
-  auto out_shape_base = common::AnfAlgo::GetOutputDetailShape(input, 0);
+  auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
   MS_EXCEPTION_IF_NULL(out_shape_base);
   ShapeVector out_shape;
   bool is_dynamic_shape = false;
@@ -555,8 +555,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
       origin_type = common::AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     }
     const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
-    const abstract::BaseShapePtr origin_shape =
-      common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
+    const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
     // In graph kernel, we check parameter,
     // the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
     if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
@@ -196,7 +196,7 @@ AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::K
   MS_EXCEPTION_IF_NULL(tuple_item);
   common::AnfAlgo::SetOutputTypeAndDetailShape(
     {common::AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)},
-    {common::AnfAlgo::GetOutputDetailShape(buffer_fusion_kernel, output_index)}, tuple_item.get());
+    {AnfAlgo::GetOutputDetailShape(buffer_fusion_kernel, output_index)}, tuple_item.get());
   return tuple_item;
 }
 
@@ -582,7 +582,7 @@ bool UbPatternFusion::ReplaceFusionOp(mindspore::HashMap<int64_t, BufferFusionIn
     size_t out_num = AnfAlgo::GetOutputTensorNum(out_node);
     for (size_t idx = 0; idx < out_num; ++idx) {
       (void)types.emplace_back(common::AnfAlgo::GetOutputInferDataType(out_node, idx));
-      (void)shapes.emplace_back(common::AnfAlgo::GetOutputDetailShape(out_node, idx));
+      (void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(out_node, idx));
     }
   }
   if (types.empty() || shapes.empty()) {
@@ -231,7 +231,7 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
   for (size_t i = 0; i < output_num; ++i) {
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, i);
     common::AnfAlgo::SetOutputTypeAndDetailShape({std::get<0>(output_info)[i]},
-                                                 {common::AnfAlgo::GetOutputDetailShape(node, i)}, tuple_getitem.get());
+                                                 {AnfAlgo::GetOutputDetailShape(node, i)}, tuple_getitem.get());
     (void)new_outputs.emplace_back(std::move(tuple_getitem));
   }
   return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size);
@@ -64,7 +64,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
   for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
     auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_idx);
    auto origin_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
-    auto origin_shape_base_ptr = common::AnfAlgo::GetPrevNodeOutputDetailShape(cnode, input_idx);
+    auto origin_shape_base_ptr = AnfAlgo::GetPrevNodeOutputDetailShape(cnode, input_idx);
     auto origin_shape_ptr = origin_shape_base_ptr->cast<abstract::ShapePtr>();
     MS_EXCEPTION_IF_NULL(origin_shape_ptr);
     auto origin_shape = origin_shape_ptr->shape();
@@ -54,7 +54,7 @@ CNodePtr InsertForInput(const FuncGraphPtr &func_graph, const CNodePtr &node, co
 
   auto in_node = common::AnfAlgo::GetInputNode(node, 0);
   auto type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
-  auto in_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(node, 0);
+  auto in_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, 0);
   auto transpose_out_shape = InferTransposeOutputShape(in_shape, perm);
 
   auto ori_out_types = AnfAlgo::GetAllOutputInferDataTypes(node);
@@ -112,7 +112,7 @@ AnfNodePtr InsertForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_
     (void)transpose_inputs.push_back(tuple_getitem);
     (void)transpose_inputs.push_back(perm_value_input);
 
-    auto shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx);
+    auto shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
     auto type = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
     auto transpose_out_shape = InferTransposeOutputShape(shape, perm);
 
@@ -113,7 +113,7 @@ AnfNodePtr DealRefOutput::AddAdditionalToRefOutput(const FuncGraphPtr &func_grap
   auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
   auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
   auto cur_shape = common::AnfAlgo::GetOutputInferShape(cnode, output_index);
-  auto detail_shape = common::AnfAlgo::GetOutputDetailShape(cnode, output_index);
+  auto detail_shape = AnfAlgo::GetOutputDetailShape(cnode, output_index);
   // insert trans
   if (origin_format != cur_format && cur_shape.size() > 1) {
     auto kernel_select = std::make_shared<KernelSelect>();
@@ -35,7 +35,7 @@ bool IsDepthwiseCase(const AnfNodePtr &node, size_t index, const std::string &fo
   if (format != kOpFormat_FRAC_Z) {
     return false;
   }
-  abstract::BaseShapePtr base_shape = common::AnfAlgo::GetOutputDetailShape(node, index);
+  abstract::BaseShapePtr base_shape = AnfAlgo::GetOutputDetailShape(node, index);
   MS_EXCEPTION_IF_NULL(base_shape);
   if (base_shape->isa<abstract::Shape>()) {
     auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
@@ -51,7 +51,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
   size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
   for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
     AnfNodePtr replace_node = nullptr;
-    const auto origin_shape = common::AnfAlgo::GetOutputDetailShape(cnode, output_idx);
+    const auto origin_shape = AnfAlgo::GetOutputDetailShape(cnode, output_idx);
     const auto origin_type = common::AnfAlgo::GetOutputInferDataType(cnode, output_idx);
     auto idx = NewValueNode(SizeToLong(output_idx));
     MS_EXCEPTION_IF_NULL(idx);
@@ -105,7 +105,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &o
   // Single output, output is not TUPLE
   if (!cnode->Type()->isa<Tuple>()) {
     const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
-    const abstract::BaseShapePtr origin_shape = common::AnfAlgo::GetOutputDetailShape(cnode, 0);
+    const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(cnode, 0);
     const TypeId origin_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
     const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
     AnfNodePtr replace_node = cnode;
@@ -31,8 +31,8 @@ bool IsDepthwiseCase(const CNodePtr &node, size_t index, const std::string &form
   if (format != kOpFormat_FRAC_Z) {
     return false;
   }
-  abstract::BaseShapePtr base_shape = is_tuple ? common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index)
-                                               : common::AnfAlgo::GetOutputDetailShape(node, index);
+  abstract::BaseShapePtr base_shape =
+    is_tuple ? AnfAlgo::GetPrevNodeOutputDetailShape(node, index) : AnfAlgo::GetOutputDetailShape(node, index);
   MS_EXCEPTION_IF_NULL(base_shape);
   if (base_shape->isa<abstract::Shape>()) {
     auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
@@ -125,7 +125,7 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(cast);
   auto cast_dtype = common::AnfAlgo::GetOutputInferDataType(cast, 0);
-  auto cast_shape = common::AnfAlgo::GetOutputDetailShape(cast, 0);
+  auto cast_shape = AnfAlgo::GetOutputDetailShape(cast, 0);
   std::vector<abstract::BaseShapePtr> shapes;
   std::vector<TypeId> types;
   size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
@@ -135,7 +135,7 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
       (void)types.emplace_back(cast_dtype);
       continue;
     }
-    (void)shapes.emplace_back(common::AnfAlgo::GetOutputDetailShape(cnode, index));
+    (void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(cnode, index));
     (void)types.emplace_back(common::AnfAlgo::GetOutputInferDataType(cnode, index));
   }
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());
@@ -20,6 +20,7 @@
 #include "include/common/utils/anfalgo.h"
 #include "mindspore/core/ops/core_ops.h"
 #include "backend/common/optimizer/optimizer.h"
+#include "backend/common/session/anf_runtime_algorithm.h"
 
 namespace mindspore {
 namespace opt {
@@ -47,8 +48,7 @@ AnfNodePtr CreateCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input, co
   if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
     AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
     MS_EXCEPTION_IF_NULL(cast);
-    common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {common::AnfAlgo::GetOutputDetailShape(input, 0)},
-                                                 cast.get());
+    common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get());
     common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(dst_type), cast);
     cast->set_scope(input->scope());
     return cast;
@@ -246,7 +246,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
 
   tile_node->set_scope(mul_node->scope());
   common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)},
-                                               {common::AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1)},
+                                               {AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1)},
                                                tile_node.get());
   // Feature map set
   std::vector<size_t> feature_map_input_indexs;
@@ -44,8 +44,7 @@ void BatchNormGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, co
 
   auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 1),
                 common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 1),
-                 common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 1), AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_update_grad.get());
 
   common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
@@ -79,7 +78,7 @@ void BatchNormGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, co
   bn_reduce_grad->set_scope(bn_grad_node->scope());
 
   auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_reduce_grad.get());
 
   common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
@@ -41,7 +41,7 @@ AnfNodePtr BCEWithLogitsLossFission::AddReduceNode(const FuncGraphPtr &func_grap
   MS_EXCEPTION_IF_NULL(new_cnode);
   auto predict_input = cnode->inputs()[kIndex1];
   auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)};
-  auto new_node_shape = {common::AnfAlgo::GetOutputDetailShape(predict_input, 0)};
+  auto new_node_shape = {AnfAlgo::GetOutputDetailShape(predict_input, 0)};
   // The kAttrReduction is necessary for InferShape of BCEWithLogitsLoss op
   common::AnfAlgo::SetNodeAttr(kAttrReduction, MakeValue("none"), new_cnode);
   common::AnfAlgo::SetOutputTypeAndDetailShape(new_node_dtype, new_node_shape, new_cnode.get());
@@ -60,7 +60,7 @@ AnfNodePtr BCEWithLogitsLossFission::AddReduceNode(const FuncGraphPtr &func_grap
   }
   auto reduce_node = NewCNode(reduce_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(reduce_node);
-  auto shape = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shape = {AnfAlgo::GetOutputDetailShape(node, 0)};
   auto type = common::AnfAlgo::GetOutputInferDataType(node, 0);
   if (type == kNumberTypeFloat16) {
     common::AnfAlgo::SetOutputTypeAndDetailShape({kNumberTypeFloat32}, shape, reduce_node.get());
@@ -45,8 +45,7 @@ void BnGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNo
 
   auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 1),
                 common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 1),
-                 common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 1), AnfAlgo::GetOutputDetailShape(bn_grad_node, 2)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_update_grad.get());
   common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
   if (common::AnfAlgo::HasNodeAttr(kAttrFormat, bn_grad_node)) {
@@ -86,7 +85,7 @@ void BnGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNo
   bn_reduce_grad->set_scope(bn_grad_node->scope());
 
   auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_reduce_grad.get());
   common::AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
   if (common::AnfAlgo::HasNodeAttr(kAttrFormat, bn_grad_node)) {
@@ -56,8 +56,7 @@ bool BnSplit::CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const C
   bn_training_reduce->set_kernel_info(kernel_info);
   auto types = {common::AnfAlgo::GetOutputInferDataType(bn_cnode, 1),
                 common::AnfAlgo::GetOutputInferDataType(bn_cnode, 1)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_cnode, 1),
-                 common::AnfAlgo::GetOutputDetailShape(bn_cnode, 1)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(bn_cnode, 1), AnfAlgo::GetOutputDetailShape(bn_cnode, 1)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, bn_training_reduce.get());
   bn_training_reduce->set_scope(bn_cnode->scope());
   if (is_dynamic) {
@@ -205,8 +204,7 @@ AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const
   MS_EXCEPTION_IF_NULL(input);
   if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
     AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
-    common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {common::AnfAlgo::GetOutputDetailShape(input, 0)},
-                                                 cast.get());
+    common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get());
    common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
     cast->set_scope(input->scope());
     return cast;
@@ -50,7 +50,7 @@ AnfNodePtr ConcatFission::CreateNewConcat(const FuncGraphPtr &func_graph, const
   if (axis_from_attr < 0) {
     axis_from_attr += SizeToLong(input_shape.size());
   }
-  auto output_shape_ptr = common::AnfAlgo::GetOutputDetailShape(origin_concat_cnode, 0);
+  auto output_shape_ptr = AnfAlgo::GetOutputDetailShape(origin_concat_cnode, 0);
   MS_EXCEPTION_IF_NULL(output_shape_ptr);
   auto output_shapeptr = output_shape_ptr->cast<abstract::ShapePtr>();
   MS_EXCEPTION_IF_NULL(output_shapeptr);
@@ -63,7 +63,7 @@ AnfNodePtr ConcatFission::CreateNewConcat(const FuncGraphPtr &func_graph, const
   auto axis = LongToSize(axis_from_attr);
   output_shape[axis] = 0;
   for (size_t i = begin_index; i < begin_index + offset; ++i) {
-    auto last_input_shape_ptr = common::AnfAlgo::GetPrevNodeOutputDetailShape(origin_concat_cnode, i - 1);
+    auto last_input_shape_ptr = AnfAlgo::GetPrevNodeOutputDetailShape(origin_concat_cnode, i - 1);
     MS_EXCEPTION_IF_NULL(last_input_shape_ptr);
     auto last_input_shapeptr = last_input_shape_ptr->cast<abstract::ShapePtr>();
     MS_EXCEPTION_IF_NULL(last_input_shapeptr);
@@ -20,6 +20,7 @@
 #include "include/common/utils/anfalgo.h"
 #include "mindspore/core/ops/core_ops.h"
 #include "backend/common/optimizer/optimizer.h"
+#include "backend/common/session/anf_runtime_algorithm.h"
 
 namespace mindspore {
 namespace opt {
@@ -46,8 +47,7 @@ AnfNodePtr CreateCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input, co
   MS_EXCEPTION_IF_NULL(input);
   if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
     AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
-    common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {common::AnfAlgo::GetOutputDetailShape(input, 0)},
-                                                 cast.get());
+    common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get());
     common::AnfAlgo::SetNodeAttr(kAttrDstType, MakeValue(static_cast<size_t>(dst_type)), cast);
     cast->set_scope(input->scope());
     return cast;
@@ -48,8 +48,8 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackpropV2(const FuncGraphPtr
   MS_EXCEPTION_IF_NULL(layer_norm_x_backprop);
   layer_norm_x_backprop->set_scope(layer_norm_grad->scope());
   auto types = {common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0), kNumberTypeFloat32};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(layer_norm_grad, 0),
-                 common::AnfAlgo::GetPrevNodeOutputDetailShape(layer_norm_grad, 1)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, 0),
+                 AnfAlgo::GetPrevNodeOutputDetailShape(layer_norm_grad, 1)};
   if (is_dynamic) {
     common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), layer_norm_x_backprop);
     common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), layer_norm_x_backprop);
@@ -78,8 +78,8 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackpropV2(
   }
   auto types = {common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputGammaIndex),
                 common::AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputGammaIndex),
-                 common::AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputGammaIndex),
+                 AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, layer_norm_beta_gamma_backprop.get());
 
   // get device shape of LayerNormGrad's 5th Input, and convert it to attr
@@ -40,7 +40,7 @@ AnfNodePtr PackFission::CreateNewPack(const FuncGraphPtr &func_graph, const CNod
   std::vector<int64_t> dyn_input_sizes{SizeToLong(offset)};
   common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack);
   // infer shape
-  auto output_shape_ptr = common::AnfAlgo::GetOutputDetailShape(origin_pack_cnode, 0);
+  auto output_shape_ptr = AnfAlgo::GetOutputDetailShape(origin_pack_cnode, 0);
   MS_EXCEPTION_IF_NULL(output_shape_ptr);
   auto output_shape = output_shape_ptr->cast<abstract::ShapePtr>();
   MS_EXCEPTION_IF_NULL(output_shape);
@@ -54,7 +54,7 @@ const AnfNodePtr ReduceSumFission::Process(const FuncGraphPtr &graph, const AnfN
   auto cnode = node->cast<CNodePtr>();
   auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
   auto keep_dims = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims);
-  auto out_shape = common::AnfAlgo::GetOutputDetailShape(cnode, 0);
+  auto out_shape = AnfAlgo::GetOutputDetailShape(cnode, 0);
   std::vector<int64_t> inp_axis;
   auto axis_value = prim->GetAttr(kAttrAxis);
   MS_EXCEPTION_IF_NULL(axis_value);
@@ -32,9 +32,9 @@ bool IsDepthwiseCase(const CNodePtr &node, const std::string &input_format, cons
   MS_EXCEPTION_IF_NULL(node);
   abstract::BaseShapePtr base_shape;
   if (input_format == kOpFormat_FRAC_Z && output_format == kOpFormat_DEFAULT) {
-    base_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(node, 0);
+    base_shape = AnfAlgo::GetPrevNodeOutputDetailShape(node, 0);
   } else if (input_format == kOpFormat_DEFAULT && output_format == kOpFormat_FRAC_Z) {
-    base_shape = common::AnfAlgo::GetOutputDetailShape(node, 0);
+    base_shape = AnfAlgo::GetOutputDetailShape(node, 0);
   } else {
     return false;
   }
@@ -330,8 +330,8 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
   MS_EXCEPTION_IF_NULL(add1);
   auto types = {common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0),
                 common::AnfAlgo::GetOutputInferDataType(sub0, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(add1, 0), common::AnfAlgo::GetOutputDetailShape(add0, 0),
-                 common::AnfAlgo::GetOutputDetailShape(sub0, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(add1, 0), AnfAlgo::GetOutputDetailShape(add0, 0),
+                 AnfAlgo::GetOutputDetailShape(sub0, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
 
   std::vector<AnfNodePtr> fusion_node_outputs;
@@ -141,8 +141,8 @@ const AnfNodePtr AdaptiveMaxPool2DFusion::Process(const FuncGraphPtr &func_graph
 
   if (height % output_h != 0 || width % output_w != 0) {
     auto types = {common::AnfAlgo::GetOutputInferDataType(adaptive_max_pool2d, 0), kNumberTypeInt64};
-    auto shapes = {common::AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0),
-                   common::AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)};
+    auto shapes = {AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0),
+                   AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)};
     common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adaptive_max_pool2d.get());
     std::vector<AnfNodePtr> multi_outputs;
     CreateMultipleOutputsOfAnfNode(func_graph, adaptive_max_pool2d, kAdaptiveMaxpool2DOutputNumber, &multi_outputs);
@@ -159,7 +159,7 @@ const AnfNodePtr AdaptiveMaxPool2DFusion::Process(const FuncGraphPtr &func_graph
                                            adaptive_max_pool2d->inputs().end());
   auto pooling = NewCNode(pooling_inputs, kernel_graph);
   auto types = {common::AnfAlgo::GetOutputInferDataType(adaptive_max_pool2d, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(adaptive_max_pool2d, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, pooling.get());
   pooling->set_scope(adaptive_max_pool2d->scope());
   SetNodeAttr(pooling, height_attr, width_attr);
@ -95,8 +95,7 @@ const AnfNodePtr BNReduceGradConv2dBackpropFilterFusion::Process(const FuncGraph
|
||||||
MS_EXCEPTION_IF_NULL(fused_dbn_dw);
|
MS_EXCEPTION_IF_NULL(fused_dbn_dw);
|
||||||
auto types = {common::AnfAlgo::GetOutputInferDataType(bnreduce_grad, 0),
|
auto types = {common::AnfAlgo::GetOutputInferDataType(bnreduce_grad, 0),
|
||||||
common::AnfAlgo::GetOutputInferDataType(conv_back_filter, 0)};
|
common::AnfAlgo::GetOutputInferDataType(conv_back_filter, 0)};
|
||||||
auto shapes = {common::AnfAlgo::GetOutputDetailShape(bnreduce_grad, 0),
|
auto shapes = {AnfAlgo::GetOutputDetailShape(bnreduce_grad, 0), AnfAlgo::GetOutputDetailShape(conv_back_filter, 0)};
|
||||||
common::AnfAlgo::GetOutputDetailShape(conv_back_filter, 0)};
|
|
||||||
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_dbn_dw.get());
|
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_dbn_dw.get());
|
||||||
fused_dbn_dw->set_scope(bnreduce_grad->scope());
|
fused_dbn_dw->set_scope(bnreduce_grad->scope());
|
||||||
common::AnfAlgo::CopyNodeAttr(kAttrFilterSizes, conv_back_filter, fused_dbn_dw);
|
common::AnfAlgo::CopyNodeAttr(kAttrFilterSizes, conv_back_filter, fused_dbn_dw);
|
||||||
|
|
|
@ -65,7 +65,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
|
||||||
auto fusion_node = NewCNode(inputs, graph);
|
auto fusion_node = NewCNode(inputs, graph);
|
||||||
MS_EXCEPTION_IF_NULL(fusion_node);
|
MS_EXCEPTION_IF_NULL(fusion_node);
|
||||||
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
|
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||||
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
|
auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
|
||||||
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
|
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
|
||||||
fusion_node->set_scope(node->scope());
|
fusion_node->set_scope(node->scope());
|
||||||
return fusion_node;
|
return fusion_node;
|
||||||
|
|
|
@ -91,7 +91,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
|
||||||
auto clip_by_value = NewCNode(inputs, graph);
|
auto clip_by_value = NewCNode(inputs, graph);
|
||||||
MS_EXCEPTION_IF_NULL(clip_by_value);
|
MS_EXCEPTION_IF_NULL(clip_by_value);
|
||||||
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
|
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||||
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
|
auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
|
||||||
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, clip_by_value.get());
|
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, clip_by_value.get());
|
||||||
clip_by_value->set_scope(node->scope());
|
clip_by_value->set_scope(node->scope());
|
||||||
return clip_by_value;
|
return clip_by_value;
|
||||||
|
|
|
@ -113,7 +113,7 @@ CNodePtr ConfusionMulGradFusion::CreateFusionNode(const FuncGraphPtr &graph, con
|
||||||
common::AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
|
common::AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
|
||||||
auto types = {common::AnfAlgo::GetOutputInferDataType(mul0, 0),
|
auto types = {common::AnfAlgo::GetOutputInferDataType(mul0, 0),
|
||||||
common::AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
|
common::AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
|
||||||
auto shapes = {common::AnfAlgo::GetOutputDetailShape(mul0, 0), common::AnfAlgo::GetOutputDetailShape(reduce_sum, 0)};
|
auto shapes = {AnfAlgo::GetOutputDetailShape(mul0, 0), AnfAlgo::GetOutputDetailShape(reduce_sum, 0)};
|
||||||
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
|
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
|
||||||
return fusion_node;
|
return fusion_node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -183,8 +183,8 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra
|
||||||
std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1);
|
std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1);
|
||||||
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0),
|
auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0), common::AnfAlgo::GetOutputInferDataType(add0, 0),
|
||||||
common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add5, 0)};
|
common::AnfAlgo::GetOutputInferDataType(add1, 0), common::AnfAlgo::GetOutputInferDataType(add5, 0)};
|
||||||
auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0), common::AnfAlgo::GetOutputDetailShape(add0, 0),
|
auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0), AnfAlgo::GetOutputDetailShape(add0, 0),
|
||||||
common::AnfAlgo::GetOutputDetailShape(add1, 0), common::AnfAlgo::GetOutputDetailShape(add5, 0)};
|
AnfAlgo::GetOutputDetailShape(add1, 0), AnfAlgo::GetOutputDetailShape(add5, 0)};
|
||||||
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
|
common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fusion_node.get());
|
||||||
|
|
||||||
std::vector<AnfNodePtr> fusion_node_outputs;
|
std::vector<AnfNodePtr> fusion_node_outputs;
|
||||||
|
|
|
@@ -70,7 +70,7 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
   MS_EXCEPTION_IF_NULL(lamb_update_with_lr);

   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, lamb_update_with_lr.get());
   lamb_update_with_lr->set_scope(node->scope());
   return lamb_update_with_lr;

@@ -83,7 +83,7 @@ const AnfNodePtr SoftmaxDropoutDoMaskV3Fusion::Process(const FuncGraphPtr &graph
   MS_EXCEPTION_IF_NULL(softmax_dropout);
   auto types = {common::AnfAlgo::GetOutputInferDataType(softmax, 0),
                 common::AnfAlgo::GetOutputInferDataType(dropout, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(softmax, 0), common::AnfAlgo::GetOutputDetailShape(dropout, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(softmax, 0), AnfAlgo::GetOutputDetailShape(dropout, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, softmax_dropout.get());
   softmax_dropout->set_scope(softmax->scope());
   common::AnfAlgo::CopyNodeAttr(kAttrAxis, softmax, softmax_dropout);

@@ -60,7 +60,7 @@ CNodePtr SquareSumFusion::GenerateSquareSumV1(const FuncGraphPtr &graph, const C
   MS_EXCEPTION_IF_NULL(kernel_info);
   square_sumv1->set_kernel_info(kernel_info);
   auto types = {common::AnfAlgo::GetOutputInferDataType(sum, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(sum, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(sum, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, square_sumv1.get());
   square_sumv1->set_scope(sum->scope());
   common::AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1);

@@ -82,7 +82,7 @@ CNodePtr SquareSumFusion::GenerateSquareSumV2(const FuncGraphPtr &graph, const C
   auto square_sumv2 = NewCNode(square_sumv2_inputs, graph);
   MS_EXCEPTION_IF_NULL(square_sumv2);
   auto types = {common::AnfAlgo::GetOutputInferDataType(sum, 0), common::AnfAlgo::GetOutputInferDataType(square, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(sum, 0), common::AnfAlgo::GetOutputDetailShape(square, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(sum, 0), AnfAlgo::GetOutputDetailShape(square, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, square_sumv2.get());
   square_sumv2->set_scope(sum->scope());
   common::AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2);

@@ -22,6 +22,7 @@
 #include "include/common/utils/comm_manager.h"
 #include "backend/common/optimizer/helper.h"
 #include "frontend/parallel/ops_info/ops_utils.h"
+#include "backend/common/session/anf_runtime_algorithm.h"

 namespace mindspore {
 namespace opt {

@@ -117,7 +118,7 @@ CNodePtr AllToAllUnifyMindIR::CreateAllToAllvNode(const FuncGraphPtr &graph, con
   (void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
   auto all_to_all_v = NewCNode(all_to_all_v_input, graph);
   MS_EXCEPTION_IF_NULL(all_to_all_v);
-  auto single_shape = common::AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL);
+  auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0UL);
   auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0UL);
   std::vector<TypeId> dtypes(split_count, single_type);
   std::vector<BaseShapePtr> shapes(split_count, single_shape);

@@ -49,11 +49,10 @@ AnfNodePtr BuildBatchNormGrad(const PatternMap &m, const AnfNodePtr &new_node) {
                 common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 2UL),
                 common::AnfAlgo::GetPrevNodeOutputInferDataType(bn_grad_node, 3UL),
                 common::AnfAlgo::GetPrevNodeOutputInferDataType(bn_grad_node, 4UL)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 0UL),
-                 common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 1UL),
-                 common::AnfAlgo::GetOutputDetailShape(bn_grad_node, 2UL),
-                 common::AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 3UL),
-                 common::AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 4UL)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(bn_grad_node, 0UL), AnfAlgo::GetOutputDetailShape(bn_grad_node, 1UL),
+                 AnfAlgo::GetOutputDetailShape(bn_grad_node, 2UL),
+                 AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 3UL),
+                 AnfAlgo::GetPrevNodeOutputDetailShape(bn_grad_node, 4UL)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_bn_grad.get());
   common::AnfAlgo::CopyNodeAttrs(bn_grad_node, new_bn_grad);
   common::AnfAlgo::SetNodeAttr(kAttrUnifyIRPassed, MakeValue(true), new_bn_grad);

@@ -64,7 +64,7 @@ CNodePtr MaxPool2MaxPoolWithArgmax::CreateMaxPoolWithArgmax(const FuncGraphPtr &
   // MaxPoolWithArgmax's second output is argmax, whose datatype is uint16 and with same shape as first output
   TypeId argmax_dtype = kNumberTypeUInt16;
   auto types = {common::AnfAlgo::GetOutputInferDataType(maxpool, 0UL), argmax_dtype};
-  auto out_shape = common::AnfAlgo::GetOutputDetailShape(maxpool, 0UL);
+  auto out_shape = AnfAlgo::GetOutputDetailShape(maxpool, 0UL);
   std::vector<BaseShapePtr> shapes = {out_shape, out_shape};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, maxpool_argmax.get());
   return maxpool_argmax;

@@ -835,7 +835,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
   auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids);
   auto centerx_dtype = common::AnfAlgo::GetOutputInferDataType(centerx, 0UL);
   auto centerx_shape = common::AnfAlgo::GetOutputInferShape(centerx, 0UL);
-  auto base_shape = common::AnfAlgo::GetOutputDetailShape(centerx, 0UL);
+  auto base_shape = AnfAlgo::GetOutputDetailShape(centerx, 0UL);
   // empty
   int64_t all_to_all_output_num =
     std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });

@@ -294,9 +294,9 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   auto tile_node = pass.NewCNode(tile_inputs, graph);
   MS_EXCEPTION_IF_NULL(tile_node);
   tile_node->set_scope(mul_node->scope());
-  common::AnfAlgo::SetOutputTypeAndDetailShape(
-    {common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1UL)},
-    {common::AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1UL)}, tile_node.get());
+  common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1UL)},
+                                               {AnfAlgo::GetPrevNodeOutputDetailShape(sparse_softmax_node, 1UL)},
+                                               tile_node.get());
   if (is_convert_const_to_attr) {
     common::AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
   }

@@ -134,8 +134,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
     auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_index);
     MS_EXCEPTION_IF_NULL(cur_input);
     const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
-    const abstract::BaseShapePtr origin_shape =
-      common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
+    const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
     TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
     if (origin_type != device_type && origin_type != kTypeUnknown && device_type != kTypeUnknown) {
       auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape);

@@ -199,7 +198,7 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &
     auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(func_output, i);
     const std::string dev_fmt = AnfAlgo::GetPrevNodeOutputFormat(func_output, i);
     if (infer_type != device_type && device_type != kTypeUnknown) {
-      const abstract::BaseShapePtr origin_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(func_output_node, i);
+      const abstract::BaseShapePtr origin_shape = AnfAlgo::GetPrevNodeOutputDetailShape(func_output_node, i);
       auto cast = AddCastOpNodeToGraph(func_graph, input_node, dev_fmt, device_type, infer_type, origin_shape);
       MS_EXCEPTION_IF_NULL(cast);
       cast->set_scope(func_output->scope());

@@ -86,7 +86,7 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
   auto transpose_op = graph->NewCNode(transpose_input);
   // 3.Set the output info of transpose.
   auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, IntToSize(used_node_index))};
-  auto transpose_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, IntToSize(used_node_index))};
+  auto transpose_shape = {AnfAlgo::GetPrevNodeOutputDetailShape(used_node, IntToSize(used_node_index))};
   common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, transpose_shape, transpose_op.get());
   // 4. Set the new edge of transpose op.
   FuncGraphManagerPtr manager = graph->manager();

@@ -131,7 +131,7 @@ bool PrintValueType::Run(const FuncGraphPtr &graph) {
     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
     for (size_t i = 0; i < output_num; i++) {
       types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i));
-      shapes.push_back(common::AnfAlgo::GetOutputDetailShape(cnode, i));
+      shapes.push_back(AnfAlgo::GetOutputDetailShape(cnode, i));
     }
     common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());
     // add build info

@@ -165,7 +165,7 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
   auto adam = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(adam);
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adam.get());
   adam->set_scope(node->scope());
   auto build_info = GenerateKernelBuildInfo(adam);

@@ -170,7 +170,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
   auto adam_weight_decay = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(adam_weight_decay);
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, adam_weight_decay.get());
   adam_weight_decay->set_scope(node->scope());

@@ -89,7 +89,7 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
   auto add_relugrad = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(add_relugrad);
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, add_relugrad.get());
   add_relugrad->set_scope(node->scope());

@@ -92,7 +92,7 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
   size_t output_num = AnfAlgo::GetOutputElementNum(node);
   for (size_t i = 0; i < output_num; i++) {
     types.push_back(common::AnfAlgo::GetOutputInferDataType(node, i));
-    shapes.push_back(common::AnfAlgo::GetOutputDetailShape(node, i));
+    shapes.push_back(AnfAlgo::GetOutputDetailShape(node, i));
   }

   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, add_relu.get());

@@ -100,7 +100,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
   MS_EXCEPTION_IF_NULL(all_to_all_v);

   // Prepare dtypes, shapes and ranks vectors.
-  auto single_shape = common::AnfAlgo::GetOutputDetailShape(split_outputs[0], 0);
+  auto single_shape = AnfAlgo::GetOutputDetailShape(split_outputs[0], 0);
   auto single_type = common::AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
   std::vector<TypeId> dtypes(split_count, single_type);
   std::vector<BaseShapePtr> shapes(split_count, single_shape);

@@ -89,7 +89,7 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co
   auto replace_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(replace_node);
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get());
   replace_node->set_scope(node->scope());
   return replace_node;

@@ -61,7 +61,7 @@ const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &gra
   auto replace_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(replace_node);
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get());
   replace_node->set_scope(node->scope());
   return replace_node;

@@ -127,7 +127,7 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
   auto replace_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(replace_node);
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, replace_node.get());
   replace_node->set_scope(node->scope());
   return replace_node;

@@ -149,7 +149,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
   for (size_t i = 0; i < output_num; i++) {
     outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i));
-    outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(batch_norm, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(batch_norm, i));
   }
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
   common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);

@@ -74,11 +74,11 @@ void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad
   auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
   for (size_t i = 0; i < output_num; ++i) {
     outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(bn_grad, i));
-    outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(bn_grad, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(bn_grad, i));
   }

   outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(relu_grad, 0));
-  outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(relu_grad, 0));
+  outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(relu_grad, 0));
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
 }

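The hunks above and below repeat one idiom: walk every output of the node being matched, collect its inferred dtype and detail shape, then stamp both onto the new fused node in a single SetOutputTypeAndDetailShape call. A condensed sketch of that idiom, using simplified stand-in types (TypeId reduced to int, BaseShapePtr to a string) rather than the real node and shape classes:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Stand-ins: TypeId as int, BaseShapePtr as std::string, nodes as a plain struct.
    using TypeId = int;
    using BaseShapePtr = std::string;

    struct Node {
      std::vector<TypeId> out_types;
      std::vector<BaseShapePtr> out_shapes;
    };

    TypeId GetOutputInferDataType(const Node &n, size_t i) { return n.out_types[i]; }
    BaseShapePtr GetOutputDetailShape(const Node &n, size_t i) { return n.out_shapes[i]; }

    void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
                                     const std::vector<BaseShapePtr> &shapes, Node *node) {
      node->out_types = types;
      node->out_shapes = shapes;
    }

    // Mirrors the per-output loop seen in BatchNormAddReluFusion, SetShapeAndType, and friends.
    void CopyOutputsTo(const Node &pattern_node, Node *fused_node) {
      std::vector<TypeId> types;
      std::vector<BaseShapePtr> shapes;
      for (size_t i = 0; i < pattern_node.out_types.size(); ++i) {
        types.push_back(GetOutputInferDataType(pattern_node, i));
        shapes.push_back(GetOutputDetailShape(pattern_node, i));
      }
      SetOutputTypeAndDetailShape(types, shapes, fused_node);
    }

    int main() {
      Node batch_norm{{1, 2}, {"shape_a", "shape_b"}};
      Node fused;
      CopyOutputsTo(batch_norm, &fused);
      return fused.out_shapes.size() == 2 ? 0 : 1;
    }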
@@ -101,7 +101,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
   auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
   for (size_t i = 0; i < output_num; i++) {
     outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i));
-    outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(batch_norm, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(batch_norm, i));
   }
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_relu.get());
   common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_relu);

@@ -101,7 +101,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
   auto output_num = AnfAlgo::GetOutputTensorNum(node);
   for (size_t i = 0; i < output_num; i++) {
     outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, i));
-    outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(node, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(node, i));
   }
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
   common::AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);

@@ -41,7 +41,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto predict_input = cnode->inputs()[1];
   auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)};
-  auto new_node_shape = {common::AnfAlgo::GetOutputDetailShape(predict_input, 0)};
+  auto new_node_shape = {AnfAlgo::GetOutputDetailShape(predict_input, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(new_node_dtype, new_node_shape, new_cnode.get());

   // Add reduce node

@@ -69,7 +69,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
   auto reduce_node = func_graph->NewCNode(reduce_inputs);
   MS_EXCEPTION_IF_NULL(reduce_node);
   auto type = common::AnfAlgo::GetOutputInferDataType(node, 0);
-  auto shape = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shape = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape({type}, shape, reduce_node.get());
   common::AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
   reduce_node->set_scope(cnode->scope());

@@ -155,7 +155,7 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
     idx->set_abstract(abstract_scalar);
     auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
     MS_EXCEPTION_IF_NULL(tuple_getitem);
-    auto shape = common::AnfAlgo::GetOutputDetailShape(node, i);
+    auto shape = AnfAlgo::GetOutputDetailShape(node, i);
     common::AnfAlgo::SetOutputTypeAndDetailShape({std::get<0>(output_info)[i]}, {shape}, tuple_getitem.get());
     new_outputs.emplace_back(std::move(tuple_getitem));
   }

@@ -158,7 +158,7 @@ void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) {
   std::vector<BaseShapePtr> shapes;
   for (size_t i = 0; i < output_num; i++) {
     types.emplace_back(common::AnfAlgo::GetOutputInferDataType(src, i));
-    shapes.emplace_back(common::AnfAlgo::GetOutputDetailShape(src, i));
+    shapes.emplace_back(AnfAlgo::GetOutputDetailShape(src, i));
   }
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, dst.get());
 }

@@ -38,7 +38,7 @@ void InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i, con
   auto cast = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(cast);
   common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(cast_type), cast);
-  auto cast_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(node, i)};
+  auto cast_shape = {AnfAlgo::GetPrevNodeOutputDetailShape(node, i)};
   common::AnfAlgo::SetOutputTypeAndDetailShape({cast_type}, cast_shape, cast.get());
   FuncGraphManagerPtr manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);

@@ -110,7 +110,7 @@ bool InsertCastGPU::Run(const FuncGraphPtr &graph) {
     auto output_types = std::vector<TypeId>(output_num, kNumberTypeFloat32);
     std::vector<BaseShapePtr> output_shapes;
     for (size_t output_index = 0; output_index < output_num; ++output_index) {
-      auto shape = common::AnfAlgo::GetOutputDetailShape(node, output_index);
+      auto shape = AnfAlgo::GetOutputDetailShape(node, output_index);
       (void)output_shapes.emplace_back(shape);
     }
     common::AnfAlgo::SetOutputTypeAndDetailShape(output_types, output_shapes, node.get());

@@ -130,7 +130,7 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
   MS_EXCEPTION_IF_NULL(transpose_op);
   // 3.Set the output info of transpose.
   auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
-  auto base_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, used_node_index);
+  auto base_shape = AnfAlgo::GetPrevNodeOutputDetailShape(used_node, used_node_index);
   common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, {base_shape}, transpose_op.get());

   // 4. Set the new edge of transpose op.

@@ -100,7 +100,7 @@ const AnfNodePtr MatMulBiasAddFusion::Process(const FuncGraphPtr &graph, const A

   // Copy Abstract and KernelBuildInfo.
   auto types = {common::AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(node, 0)};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, fused_node.get());
   common::AnfAlgo::CopyNodeAttrs(matmul, fused_node);
   fused_node->set_scope(node->scope());

@@ -91,7 +91,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
   auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
   for (size_t i = 0; i < output_num; i++) {
     outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(batch_norm, i));
-    outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(batch_norm, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(batch_norm, i));
   }
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
   common::AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);

@@ -191,7 +191,7 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
     for (size_t i = 0; i < output_num; i++) {
       types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i));
-      shapes.push_back(common::AnfAlgo::GetOutputDetailShape(cnode, i));
+      shapes.push_back(AnfAlgo::GetOutputDetailShape(cnode, i));
     }
     common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, print_fused.get());
     // add build info

@@ -83,7 +83,7 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
   auto element_num = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());

   std::vector<int64_t> mask_shape = {(element_num + kBitPerUInt - 1) / kBitPerUInt};
-  std::vector<BaseShapePtr> shapes = {common::AnfAlgo::GetOutputDetailShape(relu, 0),
+  std::vector<BaseShapePtr> shapes = {AnfAlgo::GetOutputDetailShape(relu, 0),
                                       std::make_shared<abstract::Shape>(mask_shape)};
   auto types = {common::AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32};
   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get());

@@ -110,7 +110,7 @@ CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad,
   size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad);
   for (size_t i = 0; i < output_num; i++) {
     types.push_back(common::AnfAlgo::GetOutputInferDataType(relu_grad, i));
-    shapes.push_back(common::AnfAlgo::GetOutputDetailShape(relu_grad, i));
+    shapes.push_back(AnfAlgo::GetOutputDetailShape(relu_grad, i));
   }

   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get());

@@ -44,7 +44,7 @@ AnfNodePtr BuildAdd(const PatternMap &m, const AnfNodePtr &default_node) {
   std::vector<TypeId> outputs_type;
   std::vector<BaseShapePtr> outputs_shape;
   outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(m.Get(A), 0));
-  outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(m.Get(A), 0));
+  outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(m.Get(A), 0));
   common::AnfAlgo::SetOutputTypeAndDetailShape(outputs_type, outputs_shape, default_node.get());
   AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(m.Get(m_addn)), default_node.get());
   return default_node;

@@ -51,7 +51,7 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
   auto output_num = AnfAlgo::GetOutputTensorNum(node);
   for (size_t i = 0; i < output_num; i++) {
     outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, i));
-    outputs_shape.push_back(common::AnfAlgo::GetOutputDetailShape(node, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputDetailShape(node, i));
   }
   outputs_type[kGradIndex] = common::AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);

@@ -39,7 +39,7 @@ void CopyGraphOutputTypeAndShape(const std::vector<session::KernelWithIndex> &gr
   std::vector<BaseShapePtr> shapes;
   for (const auto &item : graph_outputs) {
     types.push_back(common::AnfAlgo::GetOutputInferDataType(item.first, item.second));
-    shapes.push_back(common::AnfAlgo::GetOutputDetailShape(item.first, item.second));
+    shapes.push_back(AnfAlgo::GetOutputDetailShape(item.first, item.second));
   }

   common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, trt_node.get());

@@ -625,30 +625,6 @@ inline ShapeVector GetShape(const abstract::BaseShapePtr &base_shape) {
   return shape_ptr->shape();
 }

-namespace {
-// Get the element shape of dynamic sequence shape.
-abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (node->Shape() == nullptr || (!node->Shape()->isa<abstract::DynamicSequenceShape>())) {
-    MS_LOG(EXCEPTION) << "Invalid dynamic shape in node:" << node->DebugString() << ".";
-  }
-  if (node->abstract() == nullptr) {
-    MS_LOG(EXCEPTION) << "Empty abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
-  }
-  if (!node->abstract()->isa<abstract::AbstractSequence>()) {
-    MS_LOG(EXCEPTION) << "Not sequence abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
-  }
-  const auto &sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
-  MS_EXCEPTION_IF_NULL(sequence_abs);
-  if (!sequence_abs->dynamic_len()) {
-    MS_LOG(EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
-  }
-  const auto &element_abs = sequence_abs->dynamic_len_element_abs();
-  MS_EXCEPTION_IF_NULL(element_abs);
-  return element_abs->BuildShape();
-}
-}  // namespace
-
 ShapeVector AnfAlgo::GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
                                          size_t output_idx, bool is_real_squence_output) {
   MS_EXCEPTION_IF_NULL(node);

@@ -758,45 +734,6 @@ TypeId AnfAlgo::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t in
   return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
 }

-abstract::BaseShapePtr AnfAlgo::GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto base_shape = node->Shape();
-  MS_EXCEPTION_IF_NULL(base_shape);
-  if (base_shape->isa<abstract::Shape>()) {
-    if (output_idx == 0) {
-      return base_shape;
-    }
-    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
-                      << "]." << trace::DumpSourceLines(node);
-  } else if (base_shape->isa<abstract::TupleShape>()) {
-    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
-    MS_EXCEPTION_IF_NULL(tuple_shape);
-    if (output_idx >= tuple_shape->size()) {
-      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
-                        << " node:" << node->DebugString() << "." << trace::DumpSourceLines(node);
-    }
-    auto b_shp = (*tuple_shape)[output_idx];
-    if (b_shp->isa<abstract::Shape>() || b_shp->isa<abstract::NoShape>()) {
-      return b_shp;
-    } else {
-      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
-                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
-                        << "node :" << node->DebugString() << "." << trace::DumpSourceLines(node);
-    }
-  } else if (base_shape->isa<abstract::NoShape>()) {
-    return base_shape;
-  } else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
-    return GetDynamicSequenceShape(node, output_idx);
-  }
-  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
-                    << base_shape->ToString() << " node : " << node->DebugString() << trace::DumpSourceLines(node);
-}
-
-abstract::BaseShapePtr AnfAlgo::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
-  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
-  return AnfAlgo::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
-}
-
 // set infer shapes and types of anf node
 void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
                                           const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) {

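For reference, the dispatch of the relocated GetOutputDetailShape is easy to restate on its own: a plain Shape answers only output index 0, a TupleShape indexes into its elements, NoShape passes through, and a DynamicSequenceShape defers to the sequence's element abstract. A compilable sketch of the first three branches with minimal stand-in classes (these are not the real abstract::* types; the fatal MS_LOG(EXCEPTION) paths become exceptions, and the dynamic-sequence branch is elided):

    #include <memory>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Minimal stand-ins for abstract::BaseShape and friends.
    struct BaseShape {
      virtual ~BaseShape() = default;
    };
    using BaseShapePtr = std::shared_ptr<BaseShape>;

    struct Shape : BaseShape {       // concrete tensor shape
      std::vector<long long> dims;
    };
    struct NoShape : BaseShape {};   // scalar or shapeless output
    struct TupleShape : BaseShape {  // multi-output node
      std::vector<BaseShapePtr> elements;
    };

    // Same dispatch order as the function removed above; fatal logs become throws here.
    BaseShapePtr GetOutputDetailShape(const BaseShapePtr &base_shape, size_t output_idx) {
      if (std::dynamic_pointer_cast<Shape>(base_shape) != nullptr) {
        if (output_idx == 0) {
          return base_shape;  // single-output nodes only answer index 0
        }
        throw std::out_of_range("single-output node queried at index " + std::to_string(output_idx));
      }
      if (auto tuple = std::dynamic_pointer_cast<TupleShape>(base_shape)) {
        if (output_idx >= tuple->elements.size()) {
          throw std::out_of_range("output index past tuple size");
        }
        return tuple->elements[output_idx];  // per-output shape of a multi-output node
      }
      if (std::dynamic_pointer_cast<NoShape>(base_shape) != nullptr) {
        return base_shape;  // NoShape passes through unchanged
      }
      throw std::invalid_argument("unsupported shape kind");
    }

    int main() {
      auto tuple = std::make_shared<TupleShape>();
      tuple->elements.push_back(std::make_shared<Shape>());
      tuple->elements.push_back(std::make_shared<NoShape>());
      return GetOutputDetailShape(tuple, 1) != nullptr ? 0 : 1;
    }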
@@ -1936,5 +1873,26 @@ bool AnfAlgo::IsReduceOp(const std::string &op_name) {
                                                 prim::kPrimReduceSum->name(), prim::kPrimSquareSumV1->name()};
   return reduce_op_type.find(op_name) != reduce_op_type.end();
 }
+
+abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (node->Shape() == nullptr || (!node->Shape()->isa<abstract::DynamicSequenceShape>())) {
+    MS_LOG(EXCEPTION) << "Invalid dynamic shape in node:" << node->DebugString() << ".";
+  }
+  if (node->abstract() == nullptr) {
+    MS_LOG(EXCEPTION) << "Empty abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
+  }
+  if (!node->abstract()->isa<abstract::AbstractSequence>()) {
+    MS_LOG(EXCEPTION) << "Not sequence abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
+  }
+  const auto &sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
+  MS_EXCEPTION_IF_NULL(sequence_abs);
+  if (!sequence_abs->dynamic_len()) {
+    MS_LOG(EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
+  }
+  const auto &element_abs = sequence_abs->dynamic_len_element_abs();
+  MS_EXCEPTION_IF_NULL(element_abs);
+  return element_abs->BuildShape();
+}
 }  // namespace common
 }  // namespace mindspore

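The helper added above is a guard chain: the node must actually carry a DynamicSequenceShape, must have an abstract, that abstract must be a sequence marked dynamic-length, and only then is the shared element shape rebuilt from the element abstract. A small sketch of the same validate-then-extract style under hypothetical stand-in types; where the real helper logs a fatal exception at each guard, this version collapses to an empty optional:

    #include <memory>
    #include <optional>
    #include <string>

    // Hypothetical stand-ins: a node that may carry a dynamic-length sequence abstract.
    struct ElementAbstract {
      std::string BuildShape() const { return "element shape"; }
    };

    struct SequenceAbstract {
      bool dynamic_len = false;
      std::shared_ptr<ElementAbstract> element;
    };

    struct Node {
      bool has_dynamic_sequence_shape = false;
      std::shared_ptr<SequenceAbstract> abstract;
    };

    // Each precondition failure is fatal in the real helper (MS_LOG(EXCEPTION));
    // here the same guard chain simply yields an empty optional.
    std::optional<std::string> GetDynamicSequenceShape(const Node &node) {
      if (!node.has_dynamic_sequence_shape) return std::nullopt;   // wrong shape kind
      if (node.abstract == nullptr) return std::nullopt;           // no abstract at all
      if (!node.abstract->dynamic_len) return std::nullopt;        // fixed-length sequence
      if (node.abstract->element == nullptr) return std::nullopt;  // no element abstract
      return node.abstract->element->BuildShape();                 // shared element shape
    }

    int main() {
      Node node{true, std::make_shared<SequenceAbstract>()};
      node.abstract->dynamic_len = true;
      node.abstract->element = std::make_shared<ElementAbstract>();
      return GetDynamicSequenceShape(node).has_value() ? 0 : 1;
    }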
@@ -151,6 +151,7 @@ def init(backend_name=None):
             raise RuntimeError("Parameter server and scheduler should use 'CPU' as backend instead of 'Ascend'")
         if _get_ps_context("worker_num") == 1:
             GlobalComm.INITED = True
+            _set_elegant_exit_handle()
             return
         if device_target != "Ascend":
             raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "