Multiplex op infer

This commit is contained in:
ZPaC 2023-01-04 15:58:46 +08:00
parent f484e029a1
commit bde4255041
3 changed files with 41 additions and 41 deletions

View File

@ -21,8 +21,8 @@
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "include/common/utils/convert_utils.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace opt {
@ -240,6 +240,23 @@ void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
}
abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) {
MS_EXCEPTION_IF_NULL(primitive);
auto found = abstract::GetPrimitiveInferImpl(primitive);
if (!found.has_value()) {
MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered.";
}
std::vector<AbstractBasePtr> input_args;
std::for_each(input_list.begin(), input_list.end(),
[&input_args](const auto &input) { input_args.emplace_back(input->abstract()); });
auto infer_impl = found.value();
auto abs = infer_impl.InferShapeAndType(nullptr, primitive, input_args);
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString();
return abs;
}
abstract::AbstractBasePtr GenerateAbsByUserNodeInput(const CNodePtr &user_node, size_t input_index) {
MS_EXCEPTION_IF_NULL(user_node);
auto shape = AnfAlgo::GetInputDeviceShape(user_node, input_index);
@ -507,18 +524,17 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
// Data type of the tensor should be set as an attr of TupleToTensor op.
size_t input_index = GetInputNodeIndex(input, node);
auto abs = GenerateAbsByUserNodeInput(node, input_index);
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
tuple_to_tensor->set_abstract(abs);
// Data type of the tensor should be set as an attr of TupleToTensor op.
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
// Attr name is to be confirmed.
common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast<int64_t>(data_type)), tuple_to_tensor);
SetKernelInfoForNewCNode(tuple_to_tensor);
// Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
@ -571,18 +587,17 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
// Data type of the tensor should be set as an attr of TupleToTensor op.
size_t input_index = GetInputNodeIndex(input, node);
auto abs = GenerateAbsByUserNodeInput(node, input_index);
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
tuple_to_tensor->set_abstract(abs);
// Data type of the tensor should be set as an attr of TupleToTensor op.
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
// Attr name is to be confirmed.
common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast<int64_t>(data_type)), tuple_to_tensor);
SetKernelInfoForNewCNode(tuple_to_tensor);
return {tuple_to_tensor};
}
@ -599,12 +614,11 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &
CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(scalar_to_tensor);
// Set abstract for ScalarToTensor op according to user node's input shape and type.
size_t input_index = GetInputNodeIndex(input, node);
auto abs = GenerateAbsByUserNodeInput(node, input_index);
auto abs = GenerateAbsByOpInfer(prim::kPrimScalarToTensor, {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString();
scalar_to_tensor->set_abstract(abs);
SetKernelInfoForNewCNode(scalar_to_tensor);
return {scalar_to_tensor};
}
@ -621,9 +635,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &f
CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tensor_to_tuple);
// Set abstract for TensorToTuple op according to user node's input shape and type.
size_t input_index = GetInputNodeIndex(input, node);
auto abs = GenerateAbsByUserNodeInput(node, input_index);
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToTuple, {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString();
tensor_to_tuple->set_abstract(abs);
@ -644,24 +656,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr &
CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tensor_to_scalar);
// Set abstract for TensorToScalar op according to user node's input shape and type.
size_t input_index = GetInputNodeIndex(input, node);
auto input_node = common::AnfAlgo::GetInputNode(node, input_index);
auto origin_input_abs = input_node->abstract();
MS_EXCEPTION_IF_NULL(origin_input_abs);
abstract::AbstractScalarPtr abs = nullptr;
if (input_node->isa<ValueNode>()) {
ValuePtr tensor_value = origin_input_abs->BuildValue();
if (!tensor_value->isa<tensor::Tensor>()) {
MS_LOG(EXCEPTION) << "The abstract of " << input_node->DebugString() << " should be a tensor value.";
}
ValuePtr scalar_value = CreateValueFromTensor(tensor_value->cast<tensor::TensorPtr>());
MS_EXCEPTION_IF_NULL(scalar_value);
abs = std::make_shared<abstract::AbstractScalar>(scalar_value, scalar_value->type());
} else {
abs = std::make_shared<abstract::AbstractScalar>(
kAnyValue, TypeIdToType(common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0)));
}
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToScalar, {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString();
tensor_to_scalar->set_abstract(abs);

View File

@ -100,6 +100,9 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true
// Set kernel info for some value nodes manually.
void SetKernelInfoForValueNode(const ValueNodePtr &value_node);
// Multiplex op infer methods defined under core/ops to generate abstract of new cnode.
abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive);
// Generate abstract, format and object type for newly created node.
// They can be generated in multiple ways because new node is not processed by kernel selecting method.

View File

@ -361,9 +361,11 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) {
ASSERT_TRUE(g != nullptr);
std::vector<int64_t> shp_x{4, 2};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
std::vector<int64_t> shp_y{2};
std::vector<int64_t> shp_y{1, 3};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
AbstractBasePtrList abstract_list = {y_abstract};
auto y_tuple_abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
AbstractBasePtrList args_spec_list{x_abstract, y_tuple_abs};
auto func_graph = GetFuncGraph(g, args_spec_list);
ASSERT_TRUE(func_graph != nullptr);
AnfNodePtr reshape;