forked from mindspore-Ecosystem/mindspore
Multiplex op infer
This commit is contained in:
parent
f484e029a1
commit
bde4255041
|
@ -21,8 +21,8 @@
|
|||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -240,6 +240,23 @@ void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
|
|||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto found = abstract::GetPrimitiveInferImpl(primitive);
|
||||
if (!found.has_value()) {
|
||||
MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered.";
|
||||
}
|
||||
|
||||
std::vector<AbstractBasePtr> input_args;
|
||||
std::for_each(input_list.begin(), input_list.end(),
|
||||
[&input_args](const auto &input) { input_args.emplace_back(input->abstract()); });
|
||||
auto infer_impl = found.value();
|
||||
auto abs = infer_impl.InferShapeAndType(nullptr, primitive, input_args);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString();
|
||||
return abs;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr GenerateAbsByUserNodeInput(const CNodePtr &user_node, size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(user_node);
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(user_node, input_index);
|
||||
|
@ -507,18 +524,17 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
|
|||
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto abs = GenerateAbsByUserNodeInput(node, input_index);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
|
||||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
// Attr name is to be confirmed.
|
||||
common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast<int64_t>(data_type)), tuple_to_tensor);
|
||||
|
||||
SetKernelInfoForNewCNode(tuple_to_tensor);
|
||||
// Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
|
||||
KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
|
||||
|
@ -571,18 +587,17 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
|
|||
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto abs = GenerateAbsByUserNodeInput(node, input_index);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
|
||||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
// Attr name is to be confirmed.
|
||||
common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast<int64_t>(data_type)), tuple_to_tensor);
|
||||
|
||||
SetKernelInfoForNewCNode(tuple_to_tensor);
|
||||
return {tuple_to_tensor};
|
||||
}
|
||||
|
@ -599,12 +614,11 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &
|
|||
CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(scalar_to_tensor);
|
||||
|
||||
// Set abstract for ScalarToTensor op according to user node's input shape and type.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto abs = GenerateAbsByUserNodeInput(node, input_index);
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimScalarToTensor, {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString();
|
||||
scalar_to_tensor->set_abstract(abs);
|
||||
|
||||
SetKernelInfoForNewCNode(scalar_to_tensor);
|
||||
return {scalar_to_tensor};
|
||||
}
|
||||
|
@ -621,9 +635,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &f
|
|||
CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tensor_to_tuple);
|
||||
|
||||
// Set abstract for TensorToTuple op according to user node's input shape and type.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto abs = GenerateAbsByUserNodeInput(node, input_index);
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToTuple, {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString();
|
||||
tensor_to_tuple->set_abstract(abs);
|
||||
|
@ -644,24 +656,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr &
|
|||
CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tensor_to_scalar);
|
||||
|
||||
// Set abstract for TensorToScalar op according to user node's input shape and type.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto input_node = common::AnfAlgo::GetInputNode(node, input_index);
|
||||
auto origin_input_abs = input_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(origin_input_abs);
|
||||
abstract::AbstractScalarPtr abs = nullptr;
|
||||
if (input_node->isa<ValueNode>()) {
|
||||
ValuePtr tensor_value = origin_input_abs->BuildValue();
|
||||
if (!tensor_value->isa<tensor::Tensor>()) {
|
||||
MS_LOG(EXCEPTION) << "The abstract of " << input_node->DebugString() << " should be a tensor value.";
|
||||
}
|
||||
ValuePtr scalar_value = CreateValueFromTensor(tensor_value->cast<tensor::TensorPtr>());
|
||||
MS_EXCEPTION_IF_NULL(scalar_value);
|
||||
abs = std::make_shared<abstract::AbstractScalar>(scalar_value, scalar_value->type());
|
||||
} else {
|
||||
abs = std::make_shared<abstract::AbstractScalar>(
|
||||
kAnyValue, TypeIdToType(common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0)));
|
||||
}
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToScalar, {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString();
|
||||
tensor_to_scalar->set_abstract(abs);
|
||||
|
|
|
@ -100,6 +100,9 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true
|
|||
// Set kernel info for some value nodes manually.
|
||||
void SetKernelInfoForValueNode(const ValueNodePtr &value_node);
|
||||
|
||||
// Multiplex op infer methods defined under core/ops to generate abstract of new cnode.
|
||||
abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive);
|
||||
|
||||
// Generate abstract, format and object type for newly created node.
|
||||
// They can be generated in multiple ways because new node is not processed by kernel selecting method.
|
||||
|
||||
|
|
|
@ -361,9 +361,11 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) {
|
|||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{4, 2};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int64_t> shp_y{2};
|
||||
std::vector<int64_t> shp_y{1, 3};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
AbstractBasePtrList abstract_list = {y_abstract};
|
||||
auto y_tuple_abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_tuple_abs};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr reshape;
|
||||
|
|
Loading…
Reference in New Issue