forked from mindspore-Ecosystem/mindspore
!48924 Support dynamic shape for AOT custom op
Merge pull request !48924 from zichun_ye/custom_op_dyn
This commit is contained in:
commit
d100945336
|
@ -0,0 +1,151 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pybind_api/ir/primitive_py.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/custom_aot_extra.h"
|
||||
#include "mindspore/core/ops/custom.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
#define REGISTER_PRIMITIVE_OP_CPP_INFER_IMPL(name, primitive, OP_INFER_ClASS, is_impl_infer_value) \
|
||||
const auto helper_op_infer_##name = abstract::RegisterStandardPrimitiveEvalHelper( \
|
||||
abstract::GetPrimitiveInferMapPtr(), primitive, std::make_shared<OP_INFER_ClASS>(), is_impl_infer_value);
|
||||
|
||||
class AGCustomInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
constexpr auto kFuncName = "func_name";
|
||||
constexpr auto kAOTFuncType = "aot";
|
||||
auto func_type = GetValue<std::string>(primitive->GetAttr(kAttrFuncType));
|
||||
const auto &exec_info = GetValue<std::string>(primitive->GetAttr(kFuncName));
|
||||
if (func_type != kAOTFuncType) {
|
||||
MS_LOG(EXCEPTION) << "The custom operator of type '" << func_type
|
||||
<< "' does not support dynamic shape yet, func name:" << exec_info;
|
||||
}
|
||||
|
||||
auto kernel_name = primitive->name();
|
||||
std::string file_path, func_name;
|
||||
|
||||
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
|
||||
auto path = exec_info.substr(0, pos);
|
||||
auto real_path = FileUtils::GetRealPath(path.c_str());
|
||||
if (!real_path.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', couldn't find the AOT binary file under path: " << path;
|
||||
}
|
||||
file_path = real_path.value();
|
||||
func_name = exec_info.substr(pos + 1);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', user defined function path '" << exec_info << "' is illegal.";
|
||||
}
|
||||
|
||||
std::vector<int64_t *> input_shapes;
|
||||
std::vector<int> ndims;
|
||||
|
||||
std::vector<std::vector<int64_t>> shape_list;
|
||||
|
||||
for (size_t idx = 0; idx < input_args.size(); idx++) {
|
||||
auto params_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(kernel_name, input_args, idx);
|
||||
MS_EXCEPTION_IF_NULL(params_shape_ptr);
|
||||
auto params_shape = params_shape_ptr->shape();
|
||||
ndims.push_back(SizeToInt(params_shape.size()));
|
||||
(void)shape_list.emplace_back(params_shape);
|
||||
}
|
||||
(void)std::transform(std::begin(shape_list), std::end(shape_list), std::back_inserter(input_shapes),
|
||||
[](auto &v) { return &v[0]; });
|
||||
|
||||
void *handle = dlopen(file_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
||||
if (!handle) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', dlopen file under path" << file_path
|
||||
<< "throw the error: " << dlerror();
|
||||
}
|
||||
AotExtraImpl attrs;
|
||||
attrs.SetKernelPrim(primitive);
|
||||
|
||||
auto infer_func = reinterpret_cast<std::add_pointer<std::vector<int64_t>(int *, int64_t **, AotExtra *)>::type>(
|
||||
dlsym(handle, (func_name + "InferShape").c_str()));
|
||||
if (infer_func == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get infer shape functions failed. The custom operator does not support dynamic shape yet,"
|
||||
<< " func name:" << func_name
|
||||
<< ". Add the cpp version of the infer shape function to support dynamic shape.";
|
||||
}
|
||||
|
||||
std::vector<int64_t> ret;
|
||||
try {
|
||||
ret = infer_func(&ndims[0], &input_shapes[0], (&attrs));
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "For " << kernel_name << ", operator failed when executing user defined file " << file_path
|
||||
<< "! Error message is " << e.what();
|
||||
}
|
||||
|
||||
if (handle != nullptr) {
|
||||
dlclose(handle);
|
||||
}
|
||||
attrs.DestructKernelData();
|
||||
return std::make_shared<abstract::Shape>(ret);
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Custom Operators of type AOT doesn't support Windows currently";
|
||||
return mindspore::abstract::kNoShape;
|
||||
#endif
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_LOG(WARNING) << "This function is the fake infer dtype function and should not be entered. "
|
||||
<< "Check the dtype of the output of the operator: "
|
||||
<< GetValue<std::string>(primitive->GetAttr("func_name"));
|
||||
|
||||
return TypePtr();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
constexpr auto kCppInferShapeAttr = "cpp_infer_shape";
|
||||
constexpr auto kDTypeAttr = "dtype";
|
||||
if (!primitive->isa<PrimitivePy>()) {
|
||||
MS_LOG(EXCEPTION) << "The prim is not a PrimitivePy. Prim name: " << primitive->name();
|
||||
}
|
||||
auto prim_py = dyn_cast<PrimitivePy>(primitive);
|
||||
auto py_args = PreparePyInputs(input_args);
|
||||
auto output = prim_py->RunInfer(py_args);
|
||||
|
||||
if (!primitive->HasAttr(kCppInferShapeAttr)) {
|
||||
return abstract::PyInferRes2Abstract(prim_py, output);
|
||||
}
|
||||
|
||||
auto res_dtype = output[kDTypeAttr].cast<TypePtr>();
|
||||
if (res_dtype == nullptr) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "For custom ops with cpp infer shape functions, we support the case that the output is a tensor."
|
||||
<< "Thus the inferred dtype should be a type object, but get inferred dtype in: " << output;
|
||||
}
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto res = MakeAbstract(shape, res_dtype);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_PRIMITIVE_OP_CPP_INFER_IMPL(Custom, prim::kPrimCustom, AGCustomInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -186,6 +186,25 @@ int CustomAOTCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
shapes_.clear();
|
||||
shape_list_.clear();
|
||||
ndims_.clear();
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto in_shape = inputs[i]->GetShapeVector();
|
||||
(void)shape_list_.emplace_back(in_shape);
|
||||
ndims_.push_back(SizeToInt(in_shape.size()));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
auto out_shape = outputs[i]->GetShapeVector();
|
||||
(void)shape_list_.emplace_back(out_shape);
|
||||
ndims_.push_back(SizeToInt(out_shape.size()));
|
||||
}
|
||||
|
||||
(void)std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
|
||||
[](auto &v) { return &v[0]; });
|
||||
workspace_size_list_ = attrs_.WorkSpace();
|
||||
return static_cast<int>(KRET_OK);
|
||||
}
|
||||
|
|
|
@ -216,6 +216,26 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
|
|||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
shapes_.clear();
|
||||
shape_list_.clear();
|
||||
ndims_.clear();
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto in_shape = inputs[i]->GetShapeVector();
|
||||
(void)shape_list_.emplace_back(in_shape);
|
||||
ndims_.push_back(SizeToInt(in_shape.size()));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
auto out_shape = outputs[i]->GetShapeVector();
|
||||
(void)shape_list_.emplace_back(out_shape);
|
||||
ndims_.push_back(SizeToInt(out_shape.size()));
|
||||
}
|
||||
|
||||
(void)std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
|
||||
[](auto &v) { return &v[0]; });
|
||||
|
||||
workspace_size_list_ = attrs_.WorkSpace();
|
||||
return static_cast<int>(KRET_OK);
|
||||
}
|
||||
|
|
|
@ -480,6 +480,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
self.add_prim_attr("fn_id", func_id)
|
||||
|
||||
self.out_shape = out_shape
|
||||
if self.out_shape is None and self.func_type == "aot":
|
||||
self.add_prim_attr("cpp_infer_shape", True)
|
||||
self.out_dtype = out_dtype
|
||||
self.bprop = bprop
|
||||
self.fake_output = False
|
||||
|
@ -547,7 +549,12 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
logger.warning("{}, 'out_dtype' is an empty tuple. Add a placeholder instead. "
|
||||
"Not recommend to use it as it could be any uninitialized data.".format(self.log_prefix))
|
||||
infer_dtype = mstype.int32
|
||||
|
||||
if self.func_type == "aot":
|
||||
if infer_shape is None:
|
||||
logger.warning("{}, 'out_shape' is None. Add a placeholder instead. "
|
||||
"A CPP version of infer shape function is required "
|
||||
"in this case.".format(self.log_prefix))
|
||||
infer_shape = (1,)
|
||||
# after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
|
||||
if not isinstance(infer_shape, (tuple, list)):
|
||||
raise TypeError("{}, 'out_shape' must be one of [tuple, list, function], but got {}"
|
||||
|
|
|
@ -15,6 +15,16 @@
|
|||
*/
|
||||
#include <string.h>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include "custom_aot_extra.h"
|
||||
|
||||
extern "C" std::vector<int64_t> CustomAddInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
|
||||
const int64_t kDynRankSize = -2;
|
||||
if (shapes[0][0] == kDynRankSize) {
|
||||
return std::vector<int64_t>{shapes[0][0]};
|
||||
}
|
||||
return std::vector<int64_t>{shapes[0][0], shapes[0][1]};
|
||||
}
|
||||
|
||||
extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
|
||||
void *extra) {
|
||||
|
|
|
@ -13,6 +13,18 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include "custom_aot_extra.h"
|
||||
|
||||
extern "C" std::vector<int64_t> CustomAddInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
|
||||
const int64_t kDynRankSize = -2;
|
||||
if (shapes[0][0] == kDynRankSize) {
|
||||
return std::vector<int64_t>{shapes[0][0]};
|
||||
}
|
||||
return std::vector<int64_t>{shapes[0][0], shapes[0][1]};
|
||||
}
|
||||
|
||||
constexpr int THREADS = 1024;
|
||||
__global__ void CustomAddKernel(float *input1, float *input2, float *output, size_t size) {
|
||||
auto idx = blockIdx.x * THREADS + threadIdx.x;
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "custom_aot_extra.h"
|
||||
#include <iostream>
|
||||
|
||||
extern "C" std::vector<int64_t> CustomReduceInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
|
||||
const int64_t kDynRankSize = -2;
|
||||
|
||||
if (shapes[0][0] == kDynRankSize) {
|
||||
return std::vector<int64_t>{shapes[0][0]};
|
||||
}
|
||||
int64_t idx = extra->Attr<int64_t>("reduce_axis");
|
||||
bool keep_dim = extra->Attr<bool>("keep_dim");
|
||||
if (keep_dim) {
|
||||
if (idx == 0) {
|
||||
return std::vector<int64_t>{1, shapes[0][1]};
|
||||
} else {
|
||||
return std::vector<int64_t>{shapes[0][0], 1};
|
||||
}
|
||||
} else {
|
||||
return std::vector<int64_t>{shapes[0][1 - idx]};
|
||||
}
|
||||
}
|
||||
|
||||
class reduce_kernel_idx : public AotKernelData {
|
||||
public:
|
||||
int64_t idx;
|
||||
bool keep_dim;
|
||||
};
|
||||
|
||||
extern "C" int CustomReduceInit(int *ndims, int64_t **shapes, const char **dtypes, AotExtra *extra) {
|
||||
reduce_kernel_idx *kernel_data_ptr = new reduce_kernel_idx;
|
||||
kernel_data_ptr->idx = extra->Attr<int64_t>("reduce_axis");
|
||||
kernel_data_ptr->keep_dim = extra->Attr<bool>("keep_dim");
|
||||
extra->SetKernelData(kernel_data_ptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C" int CustomReduce(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
|
||||
void *extra_void) {
|
||||
float *input1 = static_cast<float *>(params[0]);
|
||||
float *output = static_cast<float *>(params[1]);
|
||||
AotExtra *extra = static_cast<AotExtra *>(extra_void);
|
||||
|
||||
auto kernel_ptr = static_cast<reduce_kernel_idx *>(extra->KernelData());
|
||||
|
||||
bool keep_dim = kernel_ptr->keep_dim;
|
||||
int64_t axis = kernel_ptr->idx;
|
||||
int64_t input_dim_1 = shapes[0][1];
|
||||
|
||||
int size;
|
||||
if (keep_dim) {
|
||||
size = shapes[1][0] * shapes[1][1];
|
||||
} else {
|
||||
size = shapes[1][0];
|
||||
}
|
||||
int ext = shapes[0][axis];
|
||||
for (int i = 0; i < size; i++) {
|
||||
output[i] = 0;
|
||||
for (int j = 0; j < ext; j++) {
|
||||
int idx = input_dim_1 * (i * axis + j * (1 - axis)) + i * (1 - axis) + j * axis;
|
||||
output[i] = output[i] + input1[idx];
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.nn import Cell
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import DataType, CustomRegOp
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
||||
|
||||
class AOTSingleOutputNet(Cell):
|
||||
|
@ -45,6 +46,19 @@ class AOTSingleOutputWithAttrNet(Cell):
|
|||
return self.program(x, y, 0.7)
|
||||
|
||||
|
||||
class AOTSingleOutputDynNet(Cell):
|
||||
def __init__(self, func, out_types, reg=None):
|
||||
super(AOTSingleOutputDynNet, self).__init__()
|
||||
|
||||
self.program = ops.Custom(func, None, out_types, "aot", reg_info=reg)
|
||||
self.convert_to_dynamic = inner.ConvertToDynamic(
|
||||
is_dynamic_rank=True).add_prim_attr("primitive_target", "CPU")
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.convert_to_dynamic(x)
|
||||
return self.program(x, y)
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version():
|
||||
raw_output = subprocess.check_output(["nvcc", "-V"],
|
||||
universal_newlines=True)
|
||||
|
@ -116,6 +130,19 @@ def aot_single_output_auto_compile(source_name, reg):
|
|||
assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)
|
||||
|
||||
|
||||
def aot_single_output_dyn_shape(source_name, reg):
|
||||
shape = (4, 5)
|
||||
input_x = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
input_y = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
func_path = dir_path + "/aot_test_files/" + source_name
|
||||
|
||||
test = AOTSingleOutputDynNet(func_path + ":CustomAdd", mstype.float32, reg)
|
||||
output = test(Tensor(input_x), Tensor(input_y))
|
||||
|
||||
assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)
|
||||
|
||||
|
||||
def aot_single_output_with_attr(source_name, reg):
|
||||
shape = (4, 5)
|
||||
input_x = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
|
@ -177,6 +204,7 @@ def test_aot_single_output_gpu():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
aot_single_output(get_file_path_gpu, "add.cu", "add.so", None)
|
||||
aot_single_output_auto_compile("add.cu", None)
|
||||
aot_single_output_dyn_shape("add.cu", None)
|
||||
v_major, v_mid, v_minor = get_cuda_bare_metal_version()
|
||||
if v_major >= 11 or (v_mid >= 1 and v_minor >= 168):
|
||||
aot_single_output_with_attr("add_with_attr.cu", add_gpu_info)
|
||||
|
@ -228,6 +256,57 @@ def test_aot_single_output_cpu():
|
|||
aot_single_output_with_attr_only("add_with_attr.cc", add_cpu_info_attr_only)
|
||||
|
||||
|
||||
class ReduceDynNet(Cell):
|
||||
def __init__(self, func, out_types, axis, keep_dim):
|
||||
super(ReduceDynNet, self).__init__()
|
||||
reduce_cpu_info = CustomRegOp("reduce_kernel_cpu") \
|
||||
.input(0, "x1") \
|
||||
.output(0, "y") \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.attr("reduce_axis", "required", "float", value=axis) \
|
||||
.attr("keep_dim", "required", "bool", value=keep_dim) \
|
||||
.target("CPU") \
|
||||
.get_op_info()
|
||||
self.program = ops.Custom(func, None, out_types, "aot", reg_info=reduce_cpu_info)
|
||||
self.convert_to_dynamic = inner.ConvertToDynamic(
|
||||
is_dynamic_rank=True).add_prim_attr("primitive_target", "CPU")
|
||||
|
||||
def construct(self, x):
|
||||
x = self.convert_to_dynamic(x)
|
||||
return self.program(x)
|
||||
|
||||
|
||||
def aot_reduce_dyn_shape(source_name):
|
||||
shape = (4, 5)
|
||||
axis = 1
|
||||
keep_dim = False
|
||||
input_x = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
func_path = dir_path + "/aot_test_files/" + source_name
|
||||
|
||||
test = ReduceDynNet(func_path + ":CustomReduce", mstype.float32, axis, keep_dim)
|
||||
output = test(Tensor(input_x))
|
||||
assert np.allclose(np.sum(input_x, axis, keepdims=keep_dim), output.asnumpy(), 0.001, 0.001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_aot_single_output_cpu_dyn_shape():
|
||||
"""
|
||||
Feature: custom aot operator, multiple inputs, single output, CPU, GRAPH_MODE
|
||||
Description: pre-compile xxx.cc to xxx.so, custom operator launches xxx.so
|
||||
Expectation: nn result matches numpy result
|
||||
"""
|
||||
sys = platform.system()
|
||||
if sys.lower() in {"windows", "darwin"}:
|
||||
pass
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
aot_single_output_dyn_shape("add.cc", add_cpu_info)
|
||||
aot_reduce_dyn_shape("reduce.cc")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue