fix pyfunc scalar support

This commit is contained in:
Yang Jiao 2021-11-10 16:54:43 +08:00
parent 2b112ef656
commit ef0161caa6
1 changed files with 38 additions and 14 deletions

View File

@ -229,25 +229,49 @@ bool PyFuncCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::v
}
void PyFuncCpuKernel::BuildFuncInfo(const CNodePtr &kernel_node) {
std::vector<TypeId> in_types = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
std::vector<TypeId> out_types = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
std::vector<TypeId> in_types;
std::vector<TypeId> out_types;
std::vector<std::vector<int64_t>> in_shapes;
std::vector<std::vector<int64_t>> out_shapes;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel_node); i++) {
std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
std::vector<int64_t> in_shape_tmp;
std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
in_shapes.emplace_back(in_shape_tmp);
if (AnfAlgo::HasNodeAttr("in_types", kernel_node)) {
const auto &in_type_ptrs = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "in_types");
std::for_each(in_type_ptrs.begin(), in_type_ptrs.end(),
[&in_types](auto p) { in_types.emplace_back(p->type_id()); });
} else {
in_types = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
}
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel_node); i++) {
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
std::vector<int64_t> out_shape_tmp;
std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
out_shapes.emplace_back(out_shape_tmp);
if (AnfAlgo::HasNodeAttr("out_types", kernel_node)) {
const auto &out_type_ptrs = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "out_types");
std::for_each(out_type_ptrs.begin(), out_type_ptrs.end(),
[&out_types](auto p) { out_types.emplace_back(p->type_id()); });
} else {
out_types = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
}
if (AnfAlgo::HasNodeAttr("in_shapes", kernel_node)) {
in_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "in_shapes");
} else {
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel_node); i++) {
std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
std::vector<int64_t> in_shape_tmp;
std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
in_shapes.emplace_back(in_shape_tmp);
}
}
if (AnfAlgo::HasNodeAttr("out_shapes", kernel_node)) {
out_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "out_shapes");
} else {
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel_node); i++) {
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
std::vector<int64_t> out_shape_tmp;
std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
out_shapes.emplace_back(out_shape_tmp);
}
}
if (in_shapes.size() != in_types.size()) {