forked from mindspore-Ecosystem/mindspore
change aot func return code
This commit is contained in:
parent
e318954f06
commit
7d4599e65a
|
@ -39,24 +39,27 @@ CustomAOTCpuKernelMod::~CustomAOTCpuKernelMod() {
|
|||
}
|
||||
|
||||
void CustomAOTCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
|
||||
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
|
||||
auto path = exec_info.substr(0, pos);
|
||||
auto real_path = FileUtils::GetRealPath(path.c_str());
|
||||
if (!real_path.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid file path, " << path << " does not exist.";
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on CPU, couldn't find the AOT binary file: " << path;
|
||||
}
|
||||
file_path_ = real_path.value();
|
||||
func_name_ = exec_info.substr(pos + 1);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Wrong execute info:" << exec_info;
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "For '" << kernel_name_ << "' on CPU, user defined function path '" << exec_info
|
||||
<< "' is illegal. Proper function path should follow the format of 'dir_path/file_name:func_name'";
|
||||
}
|
||||
|
||||
num_input_ = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
auto input_type_list = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
|
||||
if (num_input_ != input_type_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input shapes'size is " << num_input_ << ", while input types' size is "
|
||||
<< input_type_list.size();
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on CPU, number of input types '" << input_type_list.size()
|
||||
<< "' doesn't match number of input shapes '" << num_input_ << "'";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_input_; i++) {
|
||||
|
@ -72,8 +75,8 @@ void CustomAOTCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
auto output_type_list = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
|
||||
if (num_output_ != output_type_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "Output shapes'size is " << num_output_ << ", while output types' size is "
|
||||
<< output_type_list.size();
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on CPU, number of outputs types '" << output_type_list.size()
|
||||
<< "' doesn't match number of output shapes '" << num_output_ << "'";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_output_; i++) {
|
||||
|
@ -107,7 +110,8 @@ bool CustomAOTCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
if (!handle_) {
|
||||
handle_ = dlopen(file_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
||||
if (!handle_) {
|
||||
MS_LOG(EXCEPTION) << "Open Error: " << dlerror();
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on CPU, dlopen file '" << file_path_
|
||||
<< "' should be successful, but error occurs! Error message is: " << dlerror();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,7 +120,8 @@ bool CustomAOTCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
reinterpret_cast<std::add_pointer<int(int, void **, int *, int64_t **, const char **, void *, void *)>::type>(
|
||||
dlsym(handle_, func_name_.c_str()));
|
||||
if (auto error_info = dlerror(); error_info != nullptr) {
|
||||
MS_LOG(EXCEPTION) << error_info;
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on CPU, error occurs when fetching function '" << func_name_
|
||||
<< "'. Error info: " << error_info;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,21 +134,16 @@ bool CustomAOTCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
ret = aot_func_(nparam, ¶ms[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], nullptr, nullptr);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "CustomAOT operator failed when running user defined file " << file_path_ << "! "
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on CPU, operator failed when executing user defined file "
|
||||
<< file_path_ << "! "
|
||||
<< "Error message is " << e.what();
|
||||
}
|
||||
|
||||
switch (ret) {
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
MS_LOG(EXCEPTION) << "Number of parameters passed to AOT kernel is " << nparam
|
||||
<< ", inconsistent with what the user wants";
|
||||
case 2:
|
||||
MS_LOG(EXCEPTION) << "Type of parameters passed to AOT kernel is inconsistent with what the user wants";
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Error occurred when running AOT kernel, "
|
||||
<< "error id is " << ret;
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Return value from CPU AOT kernel(" << file_path_ << ")'s function(" << func_name_ << ") is "
|
||||
<< ret << ". "
|
||||
<< "Any return value not equal to 0 will be treated as user defined error code and we will "
|
||||
"terminate execution. If termination is not your purpose, please set return value to 0.";
|
||||
}
|
||||
|
||||
#else
|
||||
|
|
|
@ -52,7 +52,8 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
|
|||
if (!handle_) {
|
||||
handle_ = dlopen(file_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
||||
if (!handle_) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', open should be successful, but error, " << dlerror();
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' on GPU, dlopen file '" << file_path_
|
||||
<< "' should be successful, but error occurs! Error message is: " << dlerror();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -62,7 +63,8 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
|
|||
reinterpret_cast<std::add_pointer<int(int, void **, int *, int64_t **, const char **, void *, void *)>::type>(
|
||||
dlsym(handle_, func_name_.c_str()));
|
||||
if (auto error_info = dlerror(); error_info != nullptr) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', error info: " << error_info;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' on GPU, error occurs when fetching function '" << func_name_
|
||||
<< "'. Error info: " << error_info;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -76,27 +78,17 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
|
|||
ret = aot_func_(nparam, ¶ms[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], stream_ptr, nullptr);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', operator failed when running user defined file " << file_path_
|
||||
<< "! "
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' on GPU, operator failed when executing user defined file "
|
||||
<< file_path_ << "! "
|
||||
<< "Error message is " << e.what();
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (ret) {
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of parameters passed to AOT kernel is " << nparam
|
||||
<< ", inconsistent with what the user wants";
|
||||
return false;
|
||||
case 2:
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', type of parameters passed to AOT kernel is inconsistent with what the user wants";
|
||||
return false;
|
||||
default:
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', error occurred when running AOT kernel, "
|
||||
<< "error id is " << ret;
|
||||
return false;
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Return value from GPU AOT kernel(" << file_path_ << ")'s function(" << func_name_ << ") is "
|
||||
<< ret << ". "
|
||||
<< "Any return value not equal to 0 will be treated as user defined error code and we will "
|
||||
"terminate execution. If termination is not your purpose, please set return value to 0.";
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -109,19 +101,21 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
|
|||
auto path = exec_info.substr(0, pos);
|
||||
auto real_path = FileUtils::GetRealPath(path.c_str());
|
||||
if (!real_path.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the file path should be exist, but got " << path;
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on GPU, couldn't find the AOT binary file: " << path;
|
||||
}
|
||||
file_path_ = real_path.value();
|
||||
func_name_ = exec_info.substr(pos + 1);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', Wrong execute info:" << exec_info;
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "For '" << kernel_name_ << "' on GPU, user defined function path '" << exec_info
|
||||
<< "' is illegal. Proper function path should follow the format of 'dir_path/file_name:func_name'";
|
||||
}
|
||||
|
||||
num_input_ = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
auto input_type_list = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
|
||||
if (num_input_ != input_type_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << input_type_list.size()
|
||||
<< ", but got " << num_input_;
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on GPU, number of input types '" << input_type_list.size()
|
||||
<< "' doesn't match number of input shapes '" << num_input_ << "'";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_input_; i++) {
|
||||
|
@ -138,8 +132,8 @@ class CustomAOTGpuKernelMod : public NativeGpuKernelMod {
|
|||
auto output_type_list = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
|
||||
|
||||
if (num_output_ != output_type_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs should be " << output_type_list.size()
|
||||
<< ", but got " << num_output_;
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' on GPU, number of outputs types '" << output_type_list.size()
|
||||
<< "' doesn't match number of output shapes '" << num_output_ << "'";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_output_; i++) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -91,16 +91,16 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
|
||||
Return Value(int):
|
||||
|
||||
- 0: raise no Exception
|
||||
- larger than 0: will raise Exception
|
||||
- 0: MindSpore will continue to run if this aot kernel is successfully executed
|
||||
- others: MindSpore will raise exception and exit
|
||||
|
||||
Examples: see details tests/st/ops/graph_kernel/custom/aot_test_files/
|
||||
Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/
|
||||
|
||||
- Use it in Custom:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Custom(func="{path}/{file_name}:{func_name}",...)
|
||||
Custom(func="{dir_path}/{file_name}:{func_name}",...)
|
||||
(ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32))
|
||||
|
||||
out_shape (Union[function, list, tuple]): The output shape infer function or the value of output shape of
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,9 +22,7 @@ extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes
|
|||
constexpr int TOTAL_PARAM_NUM = 3;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// In this case, there are two inputs and one output, so the nparam should be 3.
|
||||
|
@ -56,6 +54,6 @@ extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes
|
|||
output[i] = input1[i] + input2[i];
|
||||
}
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,9 +29,7 @@ extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes
|
|||
constexpr int TOTAL_PARAM_NUM = 3;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are two inputs and one output, so the nparam should be 3.
|
||||
|
@ -63,6 +61,6 @@ extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes
|
|||
CustomAddKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(input2),
|
||||
static_cast<float *>(output), size);
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,9 +34,7 @@ extern "C" int CustomAddMulDiv(int nparam, void **params, int *ndims, int64_t **
|
|||
constexpr int TOTAL_PARAM_NUM = 5;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are two inputs and three outputs, so the nparam should be 5.
|
||||
|
@ -69,6 +67,6 @@ extern "C" int CustomAddMulDiv(int nparam, void **params, int *ndims, int64_t **
|
|||
CustomAddMulDivKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(input2),
|
||||
static_cast<float *>(output1), static_cast<float *>(output2),
|
||||
static_cast<float *>(output3), size);
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -33,9 +33,7 @@ extern "C" int CustomAddMulDivBprop(int nparam, void **params, int *ndims, int64
|
|||
constexpr int TOTAL_PARAM_NUM = 7;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are five inputs and two outputs, so the nparam should be 7.
|
||||
|
@ -74,6 +72,6 @@ extern "C" int CustomAddMulDivBprop(int nparam, void **params, int *ndims, int64
|
|||
static_cast<float *>(input4), static_cast<float *>(input5), static_cast<float *>(output1),
|
||||
static_cast<float *>(output2), size);
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,9 +32,7 @@ extern "C" int CustomHSquareMul(int nparam, void **params, int *ndims, int64_t *
|
|||
constexpr int TOTAL_PARAM_NUM = 3;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are two inputs and one output, so the nparam should be 3.
|
||||
|
@ -70,6 +68,6 @@ extern "C" int CustomHSquareMul(int nparam, void **params, int *ndims, int64_t *
|
|||
CustomHSquareMulKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<half *>(input2),
|
||||
static_cast<half *>(output), size);
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,16 +32,14 @@ extern "C" int CustomReorganize(int nparam, void **params, int *ndims, int64_t *
|
|||
constexpr int TOTAL_PARAM_NUM = 3;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are two inputs and one output, so the nparam should be 3.
|
||||
if (nparam != TOTAL_PARAM_NUM) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
// This is to check if the type of parameters the same as what the user wants.
|
||||
if (strcmp(dtypes[0], "float32") != 0) {
|
||||
return 2;
|
||||
|
@ -71,6 +69,6 @@ extern "C" int CustomReorganize(int nparam, void **params, int *ndims, int64_t *
|
|||
CustomReorganizeKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<int64_t *>(input2),
|
||||
static_cast<float *>(output), size);
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,9 +30,7 @@ extern "C" int CustomSquare(int nparam, void **params, int *ndims, int64_t **sha
|
|||
constexpr int TOTAL_PARAM_NUM = 2;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Any return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Any return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are one input and one output, so the nparam should be 2.
|
||||
|
@ -62,6 +60,6 @@ extern "C" int CustomSquare(int nparam, void **params, int *ndims, int64_t **sha
|
|||
// Do the computation
|
||||
CustomSquareKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(output), size);
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,9 +31,7 @@ extern "C" int CustomSquareBprop(int nparam, void **params, int *ndims, int64_t
|
|||
constexpr int TOTAL_PARAM_NUM = 4;
|
||||
|
||||
// Users can add any check on their need. If check fails, user can return any value larger than 0 to safely exit.
|
||||
// Return value larger than 0 will cause mindspore to stop computing and safely exit.
|
||||
// Specially, return 1 will show log: "Number of parameters passed is inconsistent with what the user wants".
|
||||
// return 2 will show log: "Type of parameters passed is inconsistent with what the user wants".
|
||||
// Return value not equal to 0 will cause MindSpore to stop computing and safely exit.
|
||||
|
||||
// This is to check if the num of parameters the same as what the user wants.
|
||||
// There are three inputs and one output, so the nparam should be 4.
|
||||
|
@ -66,6 +64,6 @@ extern "C" int CustomSquareBprop(int nparam, void **params, int *ndims, int64_t
|
|||
CustomSquareBpropKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(input3),
|
||||
static_cast<float *>(output), size);
|
||||
|
||||
// When return 0, mindspore will continue to run if this kernel could launch successfully.
|
||||
// When return 0, MindSpore will continue to run if this kernel could launch successfully.
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue