support dyn shape for aot custom op

fix code check problem

fix pynative infer

use python infer function in non-dyn  cases

fix cpp infer functions

add attr in infer function

fix code check problem

drop macosx platform

fix magic number

drop lite

fix magic number

drop lite

fix lite build

fix lite build

update custom op infer api

update custom op infer api

add copy right

move code to ccsrc

drop space

drop space

update reg info

update reg info

update file dir
This commit is contained in:
Zichun Ye 2023-02-15 14:09:43 +08:00
parent 38a314b934
commit 09e0c8abe8
8 changed files with 379 additions and 1 deletions

View File

@ -0,0 +1,151 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if !defined(_WIN32) && !defined(_WIN64)
#include <dlfcn.h>
#endif
#include "include/common/utils/utils.h"
#include "pybind_api/ir/primitive_py.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "abstract/ops/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "utils/file_utils.h"
#include "utils/custom_aot_extra.h"
#include "mindspore/core/ops/custom.h"
namespace mindspore {
namespace ops {
#define REGISTER_PRIMITIVE_OP_CPP_INFER_IMPL(name, primitive, OP_INFER_ClASS, is_impl_infer_value) \
const auto helper_op_infer_##name = abstract::RegisterStandardPrimitiveEvalHelper( \
abstract::GetPrimitiveInferMapPtr(), primitive, std::make_shared<OP_INFER_ClASS>(), is_impl_infer_value);
class AGCustomInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
#if !defined(_WIN32) && !defined(_WIN64)
constexpr auto kFuncName = "func_name";
constexpr auto kAOTFuncType = "aot";
auto func_type = GetValue<std::string>(primitive->GetAttr(kAttrFuncType));
const auto &exec_info = GetValue<std::string>(primitive->GetAttr(kFuncName));
if (func_type != kAOTFuncType) {
MS_LOG(EXCEPTION) << "The custom operator of type '" << func_type
<< "' does not support dynamic shape yet, func name:" << exec_info;
}
auto kernel_name = primitive->name();
std::string file_path, func_name;
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
auto path = exec_info.substr(0, pos);
auto real_path = FileUtils::GetRealPath(path.c_str());
if (!real_path.has_value()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', couldn't find the AOT binary file under path: " << path;
}
file_path = real_path.value();
func_name = exec_info.substr(pos + 1);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', user defined function path '" << exec_info << "' is illegal.";
}
std::vector<int64_t *> input_shapes;
std::vector<int> ndims;
std::vector<std::vector<int64_t>> shape_list;
for (size_t idx = 0; idx < input_args.size(); idx++) {
auto params_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(kernel_name, input_args, idx);
MS_EXCEPTION_IF_NULL(params_shape_ptr);
auto params_shape = params_shape_ptr->shape();
ndims.push_back(SizeToInt(params_shape.size()));
(void)shape_list.emplace_back(params_shape);
}
(void)std::transform(std::begin(shape_list), std::end(shape_list), std::back_inserter(input_shapes),
[](auto &v) { return &v[0]; });
void *handle = dlopen(file_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (!handle) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', dlopen file under path" << file_path
<< "throw the error: " << dlerror();
}
AotExtraImpl attrs;
attrs.SetKernelPrim(primitive);
auto infer_func = reinterpret_cast<std::add_pointer<std::vector<int64_t>(int *, int64_t **, AotExtra *)>::type>(
dlsym(handle, (func_name + "InferShape").c_str()));
if (infer_func == nullptr) {
MS_LOG(EXCEPTION) << "Get infer shape functions failed. The custom operator does not support dynamic shape yet,"
<< " func name:" << func_name
<< ". Add the cpp version of the infer shape function to support dynamic shape.";
}
std::vector<int64_t> ret;
try {
ret = infer_func(&ndims[0], &input_shapes[0], (&attrs));
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "For " << kernel_name << ", operator failed when executing user defined file " << file_path
<< "! Error message is " << e.what();
}
if (handle != nullptr) {
dlclose(handle);
}
attrs.DestructKernelData();
return std::make_shared<abstract::Shape>(ret);
#else
MS_LOG(EXCEPTION) << "Custom Operators of type AOT doesn't support Windows currently";
return mindspore::abstract::kNoShape;
#endif
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_LOG(WARNING) << "This function is the fake infer dtype function and should not be entered. "
<< "Check the dtype of the output of the operator: "
<< GetValue<std::string>(primitive->GetAttr("func_name"));
return TypePtr();
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
constexpr auto kCppInferShapeAttr = "cpp_infer_shape";
constexpr auto kDTypeAttr = "dtype";
if (!primitive->isa<PrimitivePy>()) {
MS_LOG(EXCEPTION) << "The prim is not a PrimitivePy. Prim name: " << primitive->name();
}
auto prim_py = dyn_cast<PrimitivePy>(primitive);
auto py_args = PreparePyInputs(input_args);
auto output = prim_py->RunInfer(py_args);
if (!primitive->HasAttr(kCppInferShapeAttr)) {
return abstract::PyInferRes2Abstract(prim_py, output);
}
auto res_dtype = output[kDTypeAttr].cast<TypePtr>();
if (res_dtype == nullptr) {
MS_LOG(EXCEPTION)
<< "For custom ops with cpp infer shape functions, we support the case that the output is a tensor."
<< "Thus the inferred dtype should be a type object, but get inferred dtype in: " << output;
}
auto shape = InferShape(primitive, input_args);
auto res = MakeAbstract(shape, res_dtype);
return res;
}
};
REGISTER_PRIMITIVE_OP_CPP_INFER_IMPL(Custom, prim::kPrimCustom, AGCustomInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -186,6 +186,25 @@ int CustomAOTCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
if (ret != 0) {
return ret;
}
shapes_.clear();
shape_list_.clear();
ndims_.clear();
for (size_t i = 0; i < inputs.size(); i++) {
auto in_shape = inputs[i]->GetShapeVector();
(void)shape_list_.emplace_back(in_shape);
ndims_.push_back(SizeToInt(in_shape.size()));
}
for (size_t i = 0; i < outputs.size(); i++) {
auto out_shape = outputs[i]->GetShapeVector();
(void)shape_list_.emplace_back(out_shape);
ndims_.push_back(SizeToInt(out_shape.size()));
}
(void)std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
[](auto &v) { return &v[0]; });
workspace_size_list_ = attrs_.WorkSpace();
return static_cast<int>(KRET_OK);
}

View File

@ -216,6 +216,26 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
if (ret != 0) {
return ret;
}
shapes_.clear();
shape_list_.clear();
ndims_.clear();
for (size_t i = 0; i < inputs.size(); i++) {
auto in_shape = inputs[i]->GetShapeVector();
(void)shape_list_.emplace_back(in_shape);
ndims_.push_back(SizeToInt(in_shape.size()));
}
for (size_t i = 0; i < outputs.size(); i++) {
auto out_shape = outputs[i]->GetShapeVector();
(void)shape_list_.emplace_back(out_shape);
ndims_.push_back(SizeToInt(out_shape.size()));
}
(void)std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
[](auto &v) { return &v[0]; });
workspace_size_list_ = attrs_.WorkSpace();
return static_cast<int>(KRET_OK);
}

View File

@ -480,6 +480,8 @@ class Custom(ops.PrimitiveWithInfer):
self.add_prim_attr("fn_id", func_id)
self.out_shape = out_shape
if self.out_shape is None and self.func_type == "aot":
self.add_prim_attr("cpp_infer_shape", True)
self.out_dtype = out_dtype
self.bprop = bprop
self.fake_output = False
@ -547,7 +549,12 @@ class Custom(ops.PrimitiveWithInfer):
logger.warning("{}, 'out_dtype' is an empty tuple. Add a placeholder instead. "
"Not recommend to use it as it could be any uninitialized data.".format(self.log_prefix))
infer_dtype = mstype.int32
if self.func_type == "aot":
if infer_shape is None:
logger.warning("{}, 'out_shape' is None. Add a placeholder instead. "
"A CPP version of infer shape function is required "
"in this case.".format(self.log_prefix))
infer_shape = (1,)
# after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
if not isinstance(infer_shape, (tuple, list)):
raise TypeError("{}, 'out_shape' must be one of [tuple, list, function], but got {}"

View File

@ -15,6 +15,16 @@
*/
#include <string.h>
#include <cstdint>
#include <vector>
#include "custom_aot_extra.h"
extern "C" std::vector<int64_t> CustomAddInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
const int64_t kDynRankSize = -2;
if (shapes[0][0] == kDynRankSize) {
return std::vector<int64_t>{shapes[0][0]};
}
return std::vector<int64_t>{shapes[0][0], shapes[0][1]};
}
extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
void *extra) {

View File

@ -13,6 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "custom_aot_extra.h"
extern "C" std::vector<int64_t> CustomAddInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
const int64_t kDynRankSize = -2;
if (shapes[0][0] == kDynRankSize) {
return std::vector<int64_t>{shapes[0][0]};
}
return std::vector<int64_t>{shapes[0][0], shapes[0][1]};
}
constexpr int THREADS = 1024;
__global__ void CustomAddKernel(float *input1, float *input2, float *output, size_t size) {
auto idx = blockIdx.x * THREADS + threadIdx.x;

View File

@ -0,0 +1,80 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "custom_aot_extra.h"
#include <iostream>
extern "C" std::vector<int64_t> CustomReduceInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
const int64_t kDynRankSize = -2;
if (shapes[0][0] == kDynRankSize) {
return std::vector<int64_t>{shapes[0][0]};
}
int64_t idx = extra->Attr<int64_t>("reduce_axis");
bool keep_dim = extra->Attr<bool>("keep_dim");
if (keep_dim) {
if (idx == 0) {
return std::vector<int64_t>{1, shapes[0][1]};
} else {
return std::vector<int64_t>{shapes[0][0], 1};
}
} else {
return std::vector<int64_t>{shapes[0][1 - idx]};
}
}
class reduce_kernel_idx : public AotKernelData {
public:
int64_t idx;
bool keep_dim;
};
extern "C" int CustomReduceInit(int *ndims, int64_t **shapes, const char **dtypes, AotExtra *extra) {
reduce_kernel_idx *kernel_data_ptr = new reduce_kernel_idx;
kernel_data_ptr->idx = extra->Attr<int64_t>("reduce_axis");
kernel_data_ptr->keep_dim = extra->Attr<bool>("keep_dim");
extra->SetKernelData(kernel_data_ptr);
return 0;
}
extern "C" int CustomReduce(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
void *extra_void) {
float *input1 = static_cast<float *>(params[0]);
float *output = static_cast<float *>(params[1]);
AotExtra *extra = static_cast<AotExtra *>(extra_void);
auto kernel_ptr = static_cast<reduce_kernel_idx *>(extra->KernelData());
bool keep_dim = kernel_ptr->keep_dim;
int64_t axis = kernel_ptr->idx;
int64_t input_dim_1 = shapes[0][1];
int size;
if (keep_dim) {
size = shapes[1][0] * shapes[1][1];
} else {
size = shapes[1][0];
}
int ext = shapes[0][axis];
for (int i = 0; i < size; i++) {
output[i] = 0;
for (int j = 0; j < ext; j++) {
int idx = input_dim_1 * (i * axis + j * (1 - axis)) + i * (1 - axis) + j * axis;
output[i] = output[i] + input1[idx];
}
}
return 0;
}

View File

@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
from mindspore.nn import Cell
import mindspore.ops as ops
from mindspore.ops import DataType, CustomRegOp
from mindspore.ops.operations import _inner_ops as inner
class AOTSingleOutputNet(Cell):
@ -45,6 +46,19 @@ class AOTSingleOutputWithAttrNet(Cell):
return self.program(x, y, 0.7)
class AOTSingleOutputDynNet(Cell):
def __init__(self, func, out_types, reg=None):
super(AOTSingleOutputDynNet, self).__init__()
self.program = ops.Custom(func, None, out_types, "aot", reg_info=reg)
self.convert_to_dynamic = inner.ConvertToDynamic(
is_dynamic_rank=True).add_prim_attr("primitive_target", "CPU")
def construct(self, x, y):
x = self.convert_to_dynamic(x)
return self.program(x, y)
def get_cuda_bare_metal_version():
raw_output = subprocess.check_output(["nvcc", "-V"],
universal_newlines=True)
@ -116,6 +130,19 @@ def aot_single_output_auto_compile(source_name, reg):
assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)
def aot_single_output_dyn_shape(source_name, reg):
shape = (4, 5)
input_x = np.random.normal(0, 1, shape).astype(np.float32)
input_y = np.random.normal(0, 1, shape).astype(np.float32)
dir_path = os.path.dirname(os.path.abspath(__file__))
func_path = dir_path + "/aot_test_files/" + source_name
test = AOTSingleOutputDynNet(func_path + ":CustomAdd", mstype.float32, reg)
output = test(Tensor(input_x), Tensor(input_y))
assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)
def aot_single_output_with_attr(source_name, reg):
shape = (4, 5)
input_x = np.random.normal(0, 1, shape).astype(np.float32)
@ -177,6 +204,7 @@ def test_aot_single_output_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
aot_single_output(get_file_path_gpu, "add.cu", "add.so", None)
aot_single_output_auto_compile("add.cu", None)
aot_single_output_dyn_shape("add.cu", None)
v_major, v_mid, v_minor = get_cuda_bare_metal_version()
if v_major >= 11 or (v_mid >= 1 and v_minor >= 168):
aot_single_output_with_attr("add_with_attr.cu", add_gpu_info)
@ -228,6 +256,57 @@ def test_aot_single_output_cpu():
aot_single_output_with_attr_only("add_with_attr.cc", add_cpu_info_attr_only)
class ReduceDynNet(Cell):
def __init__(self, func, out_types, axis, keep_dim):
super(ReduceDynNet, self).__init__()
reduce_cpu_info = CustomRegOp("reduce_kernel_cpu") \
.input(0, "x1") \
.output(0, "y") \
.dtype_format(DataType.None_None, DataType.None_None) \
.attr("reduce_axis", "required", "float", value=axis) \
.attr("keep_dim", "required", "bool", value=keep_dim) \
.target("CPU") \
.get_op_info()
self.program = ops.Custom(func, None, out_types, "aot", reg_info=reduce_cpu_info)
self.convert_to_dynamic = inner.ConvertToDynamic(
is_dynamic_rank=True).add_prim_attr("primitive_target", "CPU")
def construct(self, x):
x = self.convert_to_dynamic(x)
return self.program(x)
def aot_reduce_dyn_shape(source_name):
shape = (4, 5)
axis = 1
keep_dim = False
input_x = np.random.normal(0, 1, shape).astype(np.float32)
dir_path = os.path.dirname(os.path.abspath(__file__))
func_path = dir_path + "/aot_test_files/" + source_name
test = ReduceDynNet(func_path + ":CustomReduce", mstype.float32, axis, keep_dim)
output = test(Tensor(input_x))
assert np.allclose(np.sum(input_x, axis, keepdims=keep_dim), output.asnumpy(), 0.001, 0.001)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_aot_single_output_cpu_dyn_shape():
"""
Feature: custom aot operator, multiple inputs, single output, CPU, GRAPH_MODE
Description: pre-compile xxx.cc to xxx.so, custom operator launches xxx.so
Expectation: nn result matches numpy result
"""
sys = platform.system()
if sys.lower() in {"windows", "darwin"}:
pass
else:
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
aot_single_output_dyn_shape("add.cc", add_cpu_info)
aot_reduce_dyn_shape("reduce.cc")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard