forked from mindspore-Ecosystem/mindspore
fix pyfunc scalar support
This commit is contained in:
parent
2b112ef656
commit
ef0161caa6
|
@ -229,25 +229,49 @@ bool PyFuncCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::v
|
|||
}
|
||||
|
||||
void PyFuncCpuKernel::BuildFuncInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<TypeId> in_types = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
|
||||
std::vector<TypeId> out_types = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
|
||||
std::vector<TypeId> in_types;
|
||||
std::vector<TypeId> out_types;
|
||||
std::vector<std::vector<int64_t>> in_shapes;
|
||||
std::vector<std::vector<int64_t>> out_shapes;
|
||||
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel_node); i++) {
|
||||
std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
|
||||
std::vector<int64_t> in_shape_tmp;
|
||||
std::for_each(in_shape.begin(), in_shape.end(),
|
||||
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
|
||||
in_shapes.emplace_back(in_shape_tmp);
|
||||
if (AnfAlgo::HasNodeAttr("in_types", kernel_node)) {
|
||||
const auto &in_type_ptrs = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "in_types");
|
||||
std::for_each(in_type_ptrs.begin(), in_type_ptrs.end(),
|
||||
[&in_types](auto p) { in_types.emplace_back(p->type_id()); });
|
||||
} else {
|
||||
in_types = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel_node); i++) {
|
||||
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
|
||||
std::vector<int64_t> out_shape_tmp;
|
||||
std::for_each(out_shape.begin(), out_shape.end(),
|
||||
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
|
||||
out_shapes.emplace_back(out_shape_tmp);
|
||||
if (AnfAlgo::HasNodeAttr("out_types", kernel_node)) {
|
||||
const auto &out_type_ptrs = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "out_types");
|
||||
std::for_each(out_type_ptrs.begin(), out_type_ptrs.end(),
|
||||
[&out_types](auto p) { out_types.emplace_back(p->type_id()); });
|
||||
} else {
|
||||
out_types = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
|
||||
}
|
||||
|
||||
if (AnfAlgo::HasNodeAttr("in_shapes", kernel_node)) {
|
||||
in_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "in_shapes");
|
||||
} else {
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel_node); i++) {
|
||||
std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
|
||||
std::vector<int64_t> in_shape_tmp;
|
||||
std::for_each(in_shape.begin(), in_shape.end(),
|
||||
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
|
||||
in_shapes.emplace_back(in_shape_tmp);
|
||||
}
|
||||
}
|
||||
|
||||
if (AnfAlgo::HasNodeAttr("out_shapes", kernel_node)) {
|
||||
out_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "out_shapes");
|
||||
} else {
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel_node); i++) {
|
||||
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
|
||||
std::vector<int64_t> out_shape_tmp;
|
||||
std::for_each(out_shape.begin(), out_shape.end(),
|
||||
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
|
||||
out_shapes.emplace_back(out_shape_tmp);
|
||||
}
|
||||
}
|
||||
|
||||
if (in_shapes.size() != in_types.size()) {
|
||||
|
|
Loading…
Reference in New Issue