forked from mindspore-Ecosystem/mindspore
!24074 Support AOT Operator for GPU/CPU Backend
Merge pull request !24074 from jiaoy1224/pyfunc
This commit is contained in:
commit 19b04d3ff3
@@ -31,6 +31,7 @@ if(ENABLE_CPU)
         "cpu/quantum/*.cc"
         "cpu/pyfunc/*.cc"
         "cpu/rl/*.cc"
+        "cpu/custom/*.cc"
     )

     if(NOT ENABLE_MPI)
@@ -176,5 +176,10 @@ std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::
   }
   return result;
 }
+
+bool CPUKernelFactory::SearchRegisteredOp(const std::string &kernel_name) const {
+  auto iter = name_to_attr_creator_.find(kernel_name);
+  return iter != name_to_attr_creator_.end();
+}
 }  // namespace kernel
 }  // namespace mindspore
@@ -42,6 +42,7 @@ class CPUKernelFactory {
   void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs);
   void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs);
   std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
+  bool SearchRegisteredOp(const std::string &kernel_name) const;

  private:
   CPUKernelFactory() = default;
@@ -0,0 +1,145 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/cpu/custom/custom_aot_cpu_kernel.h"

#if !defined(_WIN32) && !defined(_WIN64)
#include <dlfcn.h>
#endif

#include <vector>
#include <string>
#include <algorithm>
#include <functional>
#include "abstract/utils.h"
#include "runtime/device/cpu/cpu_common.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
CustomAOTCpuKernel::~CustomAOTCpuKernel() {
#if !defined(_WIN32) && !defined(_WIN64)
  if (handle_ != nullptr) {
    dlclose(handle_);
  }
#endif
}

void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
  const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
  if (auto pos = exec_info.find(":"); pos != std::string::npos) {
    cuda_path_ = exec_info.substr(0, pos);
    func_name_ = exec_info.substr(pos + 1);
  } else {
    MS_LOG(EXCEPTION) << "Wrong execute info: " << exec_info;
  }

  num_input_ = AnfAlgo::GetInputTensorNum(kernel_node);
  auto input_type_list = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
  if (num_input_ != input_type_list.size()) {
    MS_LOG(EXCEPTION) << "Input shapes' size is " << num_input_ << ", while input types' size is "
                      << input_type_list.size();
  }

  for (size_t i = 0; i < num_input_; i++) {
    std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
    std::vector<int64_t> in_shape_tmp;
    std::for_each(in_shape.begin(), in_shape.end(),
                  [&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
    shape_list_.push_back(in_shape_tmp);
    ndims_.push_back(SizeToInt(in_shape_tmp.size()));
    type_list_.push_back(TypeId2String(input_type_list[i]));
  }

  num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
  auto output_type_list = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
  if (num_output_ != output_type_list.size()) {
    MS_LOG(EXCEPTION) << "Output shapes' size is " << num_output_ << ", while output types' size is "
                      << output_type_list.size();
  }

  for (size_t i = 0; i < num_output_; i++) {
    std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
    std::vector<int64_t> out_shape_tmp;
    std::for_each(out_shape.begin(), out_shape.end(),
                  [&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
    shape_list_.push_back(out_shape_tmp);
    ndims_.push_back(SizeToInt(out_shape_tmp.size()));
    type_list_.push_back(TypeId2String(output_type_list[i]));
  }

  std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
                 [](auto &v) { return &v[0]; });
  std::transform(std::begin(type_list_), std::end(type_list_), std::back_inserter(type_pointer_list_),
                 [](auto &str) { return str.c_str(); });
}

bool CustomAOTCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                const std::vector<AddressPtr> &outputs) {
  std::vector<void *> params;

  for (size_t i = 0; i < num_input_; i++) {
    params.push_back(GetDeviceAddress<void>(inputs, i));
  }
  for (size_t i = 0; i < num_output_; i++) {
    params.push_back(GetDeviceAddress<void>(outputs, i));
  }

#if !defined(_WIN32) && !defined(_WIN64)
  if (!handle_) {
    handle_ = dlopen(cuda_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
    if (!handle_) {
      MS_LOG(EXCEPTION) << "Open Error: " << dlerror();
    }
  }

  if (!aot_func_) {
    aot_func_ =
      reinterpret_cast<std::add_pointer<int(int, void **, int *, int64_t **, const char **, void *, void *)>::type>(
        dlsym(handle_, func_name_.c_str()));
    if (auto error_info = dlerror(); error_info != nullptr) {
      MS_LOG(EXCEPTION) << error_info;
    }
  }

  int nparam = SizeToInt(params.size());
  int ret = 0;
  if (nparam == 0) {
    ret = aot_func_(0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
  } else {
    ret = aot_func_(nparam, &params[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], nullptr, nullptr);
  }

  switch (ret) {
    case 0:
      break;
    case 1:
      MS_LOG(EXCEPTION) << "Number of parameters passed to AOT kernel is " << nparam
                        << ", inconsistent with what the user wants";
    case 2:
      MS_LOG(EXCEPTION) << "Type of parameters passed to AOT kernel is inconsistent with what the user wants";
    default:
      MS_LOG(EXCEPTION) << "Error occurred when running AOT kernel, "
                        << "error id is " << ret;
  }

#else
  MS_LOG(EXCEPTION) << "Custom AOT Operator doesn't support Windows currently";
#endif

  return true;
}
}  // namespace kernel
}  // namespace mindspore
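For readers unfamiliar with the dynamic-loading pattern used in Launch above, here is a minimal, self-contained sketch of the same dlopen/dlsym/function-pointer flow outside MindSpore. The file name "my_kernel.so" and the symbol "CustomAdd" are illustrative assumptions, not part of this commit:

// Minimal sketch of the dlopen/dlsym pattern used by the AOT kernels above.
// Assumes a POSIX system (link with -ldl on older glibc); the library and
// symbol names are placeholders.
#include <dlfcn.h>
#include <cstdint>
#include <cstdio>

using AotFunc = int (*)(int, void **, int *, int64_t **, const char **, void *, void *);

int main() {
  void *handle = dlopen("./my_kernel.so", RTLD_LAZY | RTLD_LOCAL);
  if (!handle) {
    std::fprintf(stderr, "open error: %s\n", dlerror());
    return 1;
  }
  dlerror();  // clear any stale error state before dlsym
  auto func = reinterpret_cast<AotFunc>(dlsym(handle, "CustomAdd"));
  if (const char *err = dlerror()) {
    std::fprintf(stderr, "symbol error: %s\n", err);
    dlclose(handle);
    return 1;
  }
  // Call with no tensors, mirroring the nparam == 0 branch in Launch.
  int ret = func(0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
  std::printf("kernel returned %d\n", ret);
  dlclose(handle);
  return 0;
}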
@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CUSTOM_AOT_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CUSTOM_AOT_KERNEL_H_

#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"

namespace mindspore {
namespace kernel {

class CustomAOTCpuKernel : public CPUKernel {
 public:
  CustomAOTCpuKernel() : num_input_(0), num_output_(0), handle_(nullptr), aot_func_(nullptr) {}
  ~CustomAOTCpuKernel();

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<std::vector<int64_t>> shape_list_;
  std::vector<int> ndims_;
  std::vector<std::string> type_list_;

  std::vector<int64_t *> shapes_;
  std::vector<const char *> type_pointer_list_;

  size_t num_input_;
  size_t num_output_;
  std::string cuda_path_;
  std::string func_name_;
  void *handle_;
  int (*aot_func_)(int, void **, int *, int64_t **, const char **, void *, void *);
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CUSTOM_AOT_KERNEL_H_
@@ -0,0 +1,193 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUSTOM_CUSTOM_AOT_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_CUSTOM_CUSTOM_AOT_GPU_KERNEL_H

#include <dlfcn.h>
#include <vector>
#include <string>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace kernel {
class CustomAOTGpuKernel : public GpuKernel {
 public:
  CustomAOTGpuKernel() : num_input_(0), num_output_(0), handle_(nullptr), aot_func_(nullptr) {}
  ~CustomAOTGpuKernel() override {
    if (handle_ != nullptr) {
      dlclose(handle_);
    }
  }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    std::vector<void *> params;

    for (size_t i = 0; i < num_input_; i++) {
      params.push_back(GetDeviceAddress<void>(inputs, i));
    }
    for (size_t i = 0; i < num_output_; i++) {
      params.push_back(GetDeviceAddress<void>(outputs, i));
    }

    if (!handle_) {
      handle_ = dlopen(cuda_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
      if (!handle_) {
        MS_LOG(ERROR) << "Open Error: " << dlerror();
        return false;
      }
    }

    if (!aot_func_) {
      aot_func_ =
        reinterpret_cast<std::add_pointer<int(int, void **, int *, int64_t **, const char **, void *, void *)>::type>(
          dlsym(handle_, func_name_.c_str()));
      if (auto error_info = dlerror(); error_info != nullptr) {
        MS_LOG(ERROR) << error_info;
        return false;
      }
    }

    int nparam = SizeToInt(params.size());
    int ret = 0;
    if (nparam == 0) {
      ret = aot_func_(0, nullptr, nullptr, nullptr, nullptr, stream_ptr, nullptr);
    } else {
      ret = aot_func_(nparam, &params[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], stream_ptr, nullptr);
    }

    switch (ret) {
      case 0:
        break;
      case 1:
        MS_LOG(ERROR) << "Number of parameters passed to AOT kernel is " << nparam
                      << ", inconsistent with what the user wants";
        return false;
      case 2:
        MS_LOG(ERROR) << "Type of parameters passed to AOT kernel is inconsistent with what the user wants";
        return false;
      default:
        MS_LOG(ERROR) << "Error occurred when running AOT kernel, "
                      << "error id is " << ret;
        return false;
    }

    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
    if (auto pos = exec_info.find(":"); pos != std::string::npos) {
      cuda_path_ = exec_info.substr(0, pos);
      func_name_ = exec_info.substr(pos + 1);
    } else {
      MS_LOG(ERROR) << "Wrong execute info: " << exec_info;
      return false;
    }

    num_input_ = AnfAlgo::GetInputTensorNum(kernel_node);
    auto input_type_list = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
    if (num_input_ != input_type_list.size()) {
      MS_LOG(ERROR) << "Input shapes' size is " << num_input_ << ", while input types' size is "
                    << input_type_list.size();
      return false;
    }

    for (size_t i = 0; i < num_input_; i++) {
      std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
      std::vector<int64_t> in_shape_tmp;
      std::for_each(in_shape.begin(), in_shape.end(),
                    [&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
      shape_list_.push_back(in_shape_tmp);
      ndims_.push_back(SizeToInt(in_shape_tmp.size()));
      type_list_.push_back(TypeId2String(input_type_list[i]));
    }

    num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
    auto output_type_list = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);

    if (num_output_ != output_type_list.size()) {
      MS_LOG(ERROR) << "Output shapes' size is " << num_output_ << ", while output types' size is "
                    << output_type_list.size();
      return false;
    }

    for (size_t i = 0; i < num_output_; i++) {
      std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
      std::vector<int64_t> out_shape_tmp;
      std::for_each(out_shape.begin(), out_shape.end(),
                    [&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
      shape_list_.push_back(out_shape_tmp);
      ndims_.push_back(SizeToInt(out_shape_tmp.size()));
      type_list_.push_back(TypeId2String(output_type_list[i]));
    }

    std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
                   [](auto &v) { return &v[0]; });
    std::transform(std::begin(type_list_), std::end(type_list_), std::back_inserter(type_pointer_list_),
                   [](auto &str) { return str.c_str(); });

    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    for (size_t i = 0; i < num_input_; i++) {
      size_t this_size =
        LongToSize(std::accumulate(shape_list_[i].begin(), shape_list_[i].end(), 1, std::multiplies<int64_t>()));
      this_size *= GetDtypeNbyte(type_list_[i]);
      input_size_list_.push_back(this_size);
    }
    for (size_t i = num_input_; i < (num_input_ + num_output_); i++) {
      size_t this_size =
        LongToSize(std::accumulate(shape_list_[i].begin(), shape_list_[i].end(), 1, std::multiplies<int64_t>()));

      this_size *= GetDtypeNbyte(type_list_[i]);
      output_size_list_.push_back(this_size);
    }
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  std::vector<std::vector<int64_t>> shape_list_;
  std::vector<int> ndims_;
  std::vector<std::string> type_list_;

  std::vector<int64_t *> shapes_;
  std::vector<const char *> type_pointer_list_;

  size_t num_input_;
  size_t num_output_;
  std::string cuda_path_;
  std::string func_name_;
  void *handle_;
  int (*aot_func_)(int, void **, int *, int64_t **, const char **, void *, void *);
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUSTOM_CUSTOM_AOT_GPU_KERNEL_H
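InitSizeLists above sizes each buffer as the product of its dimensions times the dtype width. A standalone sketch of the same computation is below; the dtype_nbytes helper is a stand-in assumption for GetDtypeNbyte, not the real table:

// Buffer sizing as in InitSizeLists: element count (product of dims) times
// the byte width of the dtype.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Stand-in for GetDtypeNbyte (assumption; only a few dtypes shown).
size_t dtype_nbytes(const std::string &dtype) {
  if (dtype == "float16") return 2;
  if (dtype == "float32" || dtype == "int32") return 4;
  if (dtype == "float64" || dtype == "int64") return 8;
  return 1;
}

int main() {
  std::vector<int64_t> shape = {77, 83, 4};  // the 77x83x4 example from the docstring
  size_t bytes = static_cast<size_t>(
    std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>()));
  bytes *= dtype_nbytes("float32");
  std::printf("buffer size: %zu bytes\n", bytes);  // 77 * 83 * 4 * 4 = 102256
  return 0;
}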
@@ -23,6 +23,7 @@
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/kernel_compiler/oplib/opinfo.h"
 #include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/kernel_compiler/cpu/custom/custom_aot_cpu_kernel.h"
 #include "utils/trace_base.h"

 namespace mindspore {
@@ -316,6 +317,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   // Select for dynamic kernel(both the number and data type are undetermined).
   const std::string &op_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom) &&
+      !kernel::CPUKernelFactory::GetInstance().SearchRegisteredOp(op_name)) {
+    kernel::CPUKernelRegistrar(op_name, KernelAttr(), []() { return std::make_shared<kernel::CustomAOTCpuKernel>(); });
+  }
+
   if (IsDynamicParamKernel(op_name)) {
     return UpdateDynamicKernelBuildInfoAndAttrs(kernel_node);
   }
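The hunk above implements a fallback: when a Custom op's name is not already in the CPU kernel factory, a creator for the generic AOT kernel is installed before selection runs. A toy sketch of that idea, with illustrative names rather than MindSpore's real factory API:

// Fallback registration: install a default creator if none is registered
// under the op's name. Names here are illustrative only.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Kernel {
  virtual ~Kernel() = default;
  virtual const char *Name() const = 0;
};
struct DefaultAotKernel : Kernel {
  const char *Name() const override { return "DefaultAotKernel"; }
};

using Creator = std::function<std::shared_ptr<Kernel>()>;
std::map<std::string, Creator> g_factory;

bool SearchRegisteredOp(const std::string &name) { return g_factory.find(name) != g_factory.end(); }

int main() {
  std::string op_name = "MyCustomOp";  // hypothetical Custom op name
  if (!SearchRegisteredOp(op_name)) {
    g_factory[op_name] = []() { return std::make_shared<DefaultAotKernel>(); };
  }
  std::cout << g_factory[op_name]()->Name() << std::endl;  // prints DefaultAotKernel
  return 0;
}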
@@ -25,6 +25,7 @@
 #include "backend/kernel_compiler/oplib/opinfo.h"
 #include "backend/kernel_compiler/oplib/oplib.h"
 #include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/gpu/custom/custom_aot_gpu_kernel.h"
 #include "runtime/device/gpu/cuda_common.h"
 #include "utils/ms_context.h"
 #include "utils/ms_utils.h"
@@ -463,6 +464,10 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
   builder->SetOutputsDeviceType(outputs_type);
   bool result = false;
   if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
+    const auto &kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (!kernel::GpuKernelFactory::GetInstance().SearchRegistered(kernel_name, builder->Build())) {
+      kernel::GpuKernelRegister(kernel_name, KernelAttr(), []() { return new kernel::CustomAOTGpuKernel(); });
+    }
     // Custom op select kernel from OpLib
     result = SelectCustomKernel(kernel_node, builder->Build(), &kernel_type);
   } else if (kernel_type == UNKNOWN_KERNEL_TYPE) {
@@ -91,7 +91,7 @@ class CustomRegOp(RegOp):
         Please note that target and the `func_type` of `Custom` op have some constraints.
         If func_type is "akg", target can be one of ["Ascend", "GPU"].
         If func_type is "tbe", target can only be "Ascend".
-        If func_type is "lib", target can only be "GPU".
+        If func_type is "aot", target can only be "GPU".
         If func_type is "py_func", target can only be "CPU".
         Default: None.
     """
@@ -137,10 +137,59 @@ class Custom(ops.PrimitiveWithInfer):
     This is an experimental prototype that is subject to change.

     Args:
-        func (Union[function, str]): If func is of function type, then func should be a Python function which describes
-            the computation logic of a user defined operator. The function can be one of the following:
-            1. An AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
-            2. A TBE operator implementation function.
+        func (Union[function, str]):
+            function:
+                If func is of function type, then func should be a Python function which describes
+                the computation logic of a user defined operator. The function can be one of the following:
+                1. An AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
+                2. A TBE operator implementation function.
+
+            str:
+                If func is of str type, then the str should be a path to a binary file along with a function name.
+                This can only be used when func_type is "aot". Currently "aot" supports the GPU/CPU (Linux only)
+                platforms. "aot" means ahead of time, in which case Custom directly launches a user defined
+                "xxx.so" file as an operator. Users need to compile a handwritten "xxx.cu"/"xxx.cc" file into
+                "xxx.so" ahead of time, and offer the path of the file along with a function name.
+
+                "xxx.so" file generation:
+                1) GPU platform:
+                    Given a user defined "xxx.cu" file (e.g. "{path}/add.cu"),
+                    use the nvcc command to compile it (e.g. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu").
+                2) CPU platform:
+                    Given a user defined "xxx.cc" file (e.g. "{path}/add.cc"),
+                    use the g++/gcc command to compile it (e.g. "g++ --shared -fPIC -o add.so add.cc").
+
+                Defining a "xxx.cc"/"xxx.cu" file:
+                "aot" is a cross-platform identity. The functions defined in "xxx.cc" and "xxx.cu" share the same
+                arguments. Typically, the function should be:
+
+                int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
+                         void *extra)
+
+                Parameters:
+                    nparam(int): total number of inputs plus outputs; suppose the operator has 2 inputs and 3
+                        outputs, then nparam=5.
+                    params(void **): a pointer to the array of input and output pointers; the pointer type of
+                        inputs and outputs is void *; suppose the operator has 2 inputs and 3 outputs, then the
+                        first input's pointer is params[0] and the second output's pointer is params[4].
+                    ndims(int *): a pointer to the array of input and output dimension counts; suppose params[i]
+                        is a 1024x1024 tensor and params[j] is a 77x83x4 tensor, then ndims[i]=2, ndims[j]=3.
+                    shapes(int64_t **): a pointer to the array of input and output shapes (int64_t *); the ith
+                        input's jth dimension's size is shapes[i][j] (0<=j<ndims[i]); suppose params[i]
+                        is a 2x3 tensor and params[j] is a 3x3x4 tensor, then shapes[i][0]=2, shapes[j][2]=4.
+                    dtypes(const char **): a pointer to the array of input and output types (const char *);
+                        (e.g. "float32", "float16", "float", "float64", "int", "int8", "int16", "int32", "int64",
+                        "uint", "uint8", "uint16", "uint32", "uint64", "bool")
+                    stream(void *): stream pointer, only used in the cuda file.
+                    extra(void *): reserved for further extension.
+                Return value (int):
+                    0: raises no exception.
+                    larger than 0: an exception will be raised.
+                Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/.
+
+                Use it in Custom:
+                    Custom(func="{path}/{file_name}:{func_name}",...)
+                    (e.g. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_type=mstype.float32))

         out_shape (Union[function, list, tuple]): The output shape infer function or the value of output shape of func.
             If func has single output, then the value of output shape is a list.
             If func has multiple outputs, then the value of output shape is a tuple of list, each list represents the
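To make the calling convention documented above concrete, here is a minimal, hedged sketch of a compliant "xxx.cc" file (an element-wise float32 copy; the file name "copy.cc" and function name "CustomCopy" are illustrative, not part of this commit):

// copy.cc -- minimal sketch of the documented AOT calling convention.
// Build per the docstring (assumption): g++ --shared -fPIC -o copy.so copy.cc
#include <cstddef>
#include <cstdint>
#include <cstring>

extern "C" int CustomCopy(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
                          void *extra) {
  if (nparam != 2) return 1;  // expect exactly one input and one output
  if (std::strcmp(dtypes[0], "float32") != 0 || std::strcmp(dtypes[1], "float32") != 0) return 2;
  size_t size = 1;
  for (int i = 0; i < ndims[1]; i++) {
    size *= shapes[1][i];  // element count of the output tensor
  }
  std::memcpy(params[1], params[0], size * sizeof(float));
  return 0;  // 0 means success; any value > 0 raises an exception on the frontend
}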
@@ -149,8 +198,8 @@ class Custom(ops.PrimitiveWithInfer):
             function or the value of output dtype of func.
             If func has single output, then the value of output shape is a mindspore.dtype.
             If func has multiple outputs, then the value of output shape is a tuple of mindspore.dtype.
-        func_type (str): The implementation type of func, should be one of ["akg", "tbe", "lib", "py_func"].
-        grad (function): The gradient function of func. Default: None.
+        func_type (str): The implementation type of func, should be one of ["akg", "tbe", "aot", "py_func"].
+        bprop (function): The gradient function of func. Default: None.
         reg_info (Union[str, dict, list, tuple]): Represents the registration information of func with json format of
             type str or dict. The registration information specifies supported formats of input and output, attributes
             and target of func. If reg_info is a list or tuple, then each item should be with json format of type str
@@ -177,54 +226,78 @@ class Custom(ops.PrimitiveWithInfer):
         >>> from mindspore.ops.op_info_register import DataType
         >>> from mindspore.nn import Cell
         >>>
+        >>> #func_type="tbe"
+        >>>
         >>> square_with_bias_op_info = CustomRegOp() \
-        >>>     .fusion_type("OPAQUE") \
-        >>>     .attr("bias", "required", "float") \
-        >>>     .input(0, "x") \
-        >>>     .output(0, "y") \
-        >>>     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-        >>>     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-        >>>     .target("Ascend") \
-        >>>     .get_op_info()
+        ...     .fusion_type("OPAQUE") \
+        ...     .attr("bias", "required", "float") \
+        ...     .input(0, "x") \
+        ...     .output(0, "y") \
+        ...     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+        ...     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+        ...     .target("Ascend") \
+        ...     .get_op_info()
         >>>
         >>> @custom_op_info_register(square_with_bias_op_info)
-        >>> def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
-        >>>     import te.lang.cce
-        >>>     from te import tvm
-        >>>     from topi import generic
-        >>>     from topi.cce import util
-        >>>
-        >>>     shape = input_x.get("shape")
-        >>>     dtype = input_x.get("dtype").lower()
-        >>>
-        >>>     shape = util.shape_refine(shape)
-        >>>     data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
-        >>>
-        >>>     with tvm.target.cce():
-        >>>         res0 = te.lang.cce.vmul(data, data)
-        >>>         res = te.lang.cce.vadds(res0, bias)
-        >>>         sch = generic.auto_schedule(res)
-        >>>
-        >>>     config = {"print_ir": False,
-        >>>               "name": kernel_name,
-        >>>               "tensor_list": [data, res]}
-        >>>
-        >>>     te.lang.cce.cce_build_code(sch, config)
+        ... def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
+        ...     import te.lang.cce
+        ...     from te import tvm
+        ...     from topi import generic
+        ...     from topi.cce import util
+        ...
+        ...     shape = input_x.get("shape")
+        ...     dtype = input_x.get("dtype").lower()
+        ...
+        ...     shape = util.shape_refine(shape)
+        ...     data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
+        ...
+        ...     with tvm.target.cce():
+        ...         res0 = te.lang.cce.vmul(data, data)
+        ...         res = te.lang.cce.vadds(res0, bias)
+        ...         sch = generic.auto_schedule(res)
+        ...
+        ...     config = {"print_ir": False,
+        ...               "name": kernel_name,
+        ...               "tensor_list": [data, res]}
+        ...
+        ...     te.lang.cce.cce_build_code(sch, config)
         >>>
         >>> class Net(Cell):
-        >>>     def __init__(self):
-        >>>         super(Net1, self).__init__()
-        >>>         self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
-        >>>                                        func_type="tbe")
+        ...     def __init__(self):
+        ...         super(Net1, self).__init__()
+        ...         self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
+        ...                                        func_type="tbe")
+        ...
+        ...     def construct(self, x):
+        ...         res = self.square_with_bias(x, 1.0)
+        ...         return res
         >>>
-        >>>     def construct(self, x):
-        >>>         res = self.square_with_bias(x, 1.0)
-        >>>         return res
+        >>> #func_type="aot", platform=GPU
+        >>>
+        >>> class AOTSingleOutputNet(Cell):
+        ...     def __init__(self, func, out_shapes, out_types, reg=None):
+        ...         super(AOTSingleOutputNet, self).__init__()
+        ...         self.program = Custom(func, out_shapes, out_types, "aot", reg_info=reg)
+        ...     def construct(self, x, y):
+        ...         return self.program(x, y)
+        >>>
+        >>> reorganize_op_info = CustomRegOp() \
+        ...     .fusion_type("OPAQUE") \
+        ...     .input(0, "x1") \
+        ...     .input(1, "x2") \
+        ...     .output(0, "y") \
+        ...     .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
+        ...     .target("GPU") \
+        ...     .get_op_info()
+        >>>
+        >>> #test = AOTSingleOutputNet("./reorganize.so:CustomReorganize", shape, mstype.float32, reorganize_gpu_info)
+        >>> #output = test(Tensor(input_x), Tensor(input_y))
+        >>> #see more details in tests/st/ops/graph_kernel/custom/test_custom_aot.py
     """

     registered_func = {}

-    def __init__(self, func, out_shape, out_dtype, func_type, grad=None, reg_info=None):
+    def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None):
         ops.PrimitiveWithInfer.__init__(self, "Custom")

         self.supported_targets = ["Ascend", "GPU", "CPU"]
@@ -242,7 +315,7 @@ class Custom(ops.PrimitiveWithInfer):
         self.add_prim_attr("func_name", self.func_name)
         self.out_shape = out_shape
         self.out_dtype = out_dtype
-        self.grad = grad
+        self.bprop = bprop
         self.func_type = func_type
         # Register info
         self.register_info(reg_info)
@@ -272,7 +345,7 @@ class Custom(ops.PrimitiveWithInfer):
         return self.out_dtype

     def get_bprop(self):
-        return self.grad
+        return self.bprop

     def register_info(self, info):
         """Register reg_info."""
@@ -392,7 +465,7 @@ class Custom(ops.PrimitiveWithInfer):
                 reg_info["imply_type"].strip():
             return reg_info["imply_type"]
         # Infer imply_type from func_type
-        func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "lib": target, "py_func": target}
+        func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "aot": target, "py_func": target}
         return func_type_to_imply_type.get(self.func_type, "AKG")

     def add_inputs_name_to_attr(self, reg_info):
@@ -0,0 +1,40 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string.h>
using size_t = decltype(sizeof(int));
using int64_t = decltype(sizeof(long));

extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
                         void *extra) {
  if (nparam != 3) return 1;
  float *input1 = static_cast<float *>(params[0]);
  float *input2 = static_cast<float *>(params[1]);
  float *output = static_cast<float *>(params[2]);
  size_t size = 1;

  for (int i = 0; i < ndims[2]; i++) {
    size *= shapes[2][i];
  }
  for (int i = 0; i < nparam; i++) {
    if (strcmp(dtypes[i], "float32") != 0) {
      return 2;
    }
  }
  for (int i = 0; i < size; i++) {
    output[i] = input1[i] + input2[i];
  }
  return 0;
}
@@ -0,0 +1,46 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define THREADS 1024
__global__ void CustomAddKernel(float *input1, float *input2, float *output, size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output[idx] = input1[idx] + input2[idx];
  }
}

extern "C" int CustomAdd(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
                         void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 3) return 1;
  void *input1 = params[0];
  void *input2 = params[1];
  void *output = params[2];
  size_t size = 1;

  for (int i = 0; i < ndims[2]; i++) {
    size *= shapes[2][i];
  }
  int n = size / THREADS;
  for (int i = 0; i < nparam; i++) {
    if (strcmp(dtypes[i], "float32") != 0) {
      return 2;
    }
  }
  CustomAddKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(input2),
                                                   static_cast<float *>(output), size);
  return 0;
}
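A side note on the launch configuration in these .cu samples: size / THREADS + 1 blocks always covers size elements (the in-kernel idx < size guard discards the overhang), at the cost of one spare block whenever size is an exact multiple of THREADS. The more common ceil-division form is sketched below as a hedged alternative, not what the samples use:

// Ceil-division grid sizing: the smallest block count that covers `size`.
#include <cstdio>

int main() {
  const int kThreads = 1024;
  for (int size : {1, 1023, 1024, 1025, 4096}) {
    int blocks = (size + kThreads - 1) / kThreads;
    // Compare against the samples' size / kThreads + 1 formula.
    std::printf("size=%d -> blocks=%d (sample launches %d)\n", size, blocks, size / kThreads + 1);
  }
  return 0;
}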
@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define THREADS 1024
__global__ void CustomAddMulDivKernel(float *input1, float *input2, float *output1, float *output2, float *output3,
                                      size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output1[idx] = input1[idx] + input2[idx];
    output2[idx] = input1[idx] * input2[idx];
    output3[idx] = input1[idx] / input2[idx];
  }
}

extern "C" int CustomAddMulDiv(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                               void *stream, void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 5) return 1;
  void *input1 = params[0];
  void *input2 = params[1];
  void *output1 = params[2];
  void *output2 = params[3];
  void *output3 = params[4];
  size_t size = 1;

  for (int i = 0; i < ndims[2]; i++) {
    size *= shapes[2][i];
  }
  int n = size / THREADS;
  for (int i = 0; i < nparam; i++) {
    if (strcmp(dtypes[i], "float32") != 0) {
      return 2;
    }
  }

  CustomAddMulDivKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(input2),
                                                         static_cast<float *>(output1), static_cast<float *>(output2),
                                                         static_cast<float *>(output3), size);

  return 0;
}
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define THREADS 1024

__global__ void CustomAddMulDivBpropKernel(float *input1, float *input2, float *input3, float *input4, float *input5,
                                           float *output1, float *output2, size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output1[idx] = input3[idx] + input4[idx] * input2[idx] + input5[idx] / input2[idx];
    output2[idx] = input3[idx] + input4[idx] * input1[idx] - input5[idx] * input1[idx] / input2[idx] / input2[idx];
  }
}

extern "C" int CustomAddMulDivBprop(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                                    void *stream, void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 7) return 1;
  void *input1 = params[0];
  void *input2 = params[1];
  void *input3 = params[2];
  void *input4 = params[3];
  void *input5 = params[4];
  void *output1 = params[5];
  void *output2 = params[6];

  size_t size = 1;

  for (int i = 0; i < ndims[6]; i++) {
    size *= shapes[6][i];
  }
  int n = size / THREADS;
  for (int i = 0; i < nparam; i++) {
    if (strcmp(dtypes[i], "float32") != 0) {
      return 2;
    }
  }

  CustomAddMulDivBpropKernel<<<n + 1, THREADS, 0, custream>>>(
    static_cast<float *>(input1), static_cast<float *>(input2), static_cast<float *>(input3),
    static_cast<float *>(input4), static_cast<float *>(input5), static_cast<float *>(output1),
    static_cast<float *>(output2), size);

  return 0;
}
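For reference, the formulas this bprop kernel implements are the standard gradients of the three forward outputs. With $z_1 = x + y$, $z_2 = xy$, $z_3 = x/y$, and incoming gradients $\bar z_1, \bar z_2, \bar z_3$ (inputs 3 through 5 above), the kernel computes

\[
\frac{\partial L}{\partial x} = \bar z_1 + \bar z_2\, y + \frac{\bar z_3}{y},
\qquad
\frac{\partial L}{\partial y} = \bar z_1 + \bar z_2\, x - \frac{\bar z_3\, x}{y^{2}},
\]

which matches output1 and output2 in CustomAddMulDivBpropKernel element by element.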
@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#define THREADS 1024

__global__ void CustomHSquareMulKernel(float *input1, half *input2, half *output, size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output[idx] = __float2half(input1[idx] * input1[idx] * __half2float(input2[idx]));
  }
}

extern "C" int CustomHSquareMul(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                                void *stream, void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 3) return 1;
  void *input1 = params[0];
  void *input2 = params[1];

  void *output = params[2];
  size_t size = 1;

  for (int i = 0; i < ndims[2]; i++) {
    size *= shapes[2][i];
  }
  int n = size / THREADS;

  if (strcmp(dtypes[0], "float32") != 0) {
    return 2;
  }
  if (strcmp(dtypes[1], "float16") != 0) {
    return 2;
  }
  if (strcmp(dtypes[2], "float16") != 0) {
    return 2;
  }

  CustomHSquareMulKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<half *>(input2),
                                                          static_cast<half *>(output), size);
  return 0;
}
@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define THREADS 1024
__global__ void CustomReorganizeKernel(float *input1, int64_t *input2, float *output, size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output[idx] = input1[input2[idx]];
  }
}

extern "C" int CustomReorganize(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                                void *stream, void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 3) return 1;
  void *input1 = params[0];
  void *input2 = params[1];

  void *output = params[2];

  size_t size = 1;

  for (int i = 0; i < ndims[2]; i++) {
    size *= shapes[2][i];
  }
  int n = size / THREADS;

  if (strcmp(dtypes[0], "float32") != 0) {
    return 2;
  }
  if (strcmp(dtypes[1], "int64") != 0) {
    return 2;
  }
  if (strcmp(dtypes[2], "float32") != 0) {
    return 2;
  }
  CustomReorganizeKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<int64_t *>(input2),
                                                          static_cast<float *>(output), size);

  return 0;
}
@@ -0,0 +1,45 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#define THREADS 1024
__global__ void CustomSquareKernel(float *input1, float *output, size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output[idx] = input1[idx] * input1[idx];
  }
}

extern "C" int CustomSquare(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
                            void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 2) return 1;
  void *input1 = params[0];
  void *output = params[1];

  size_t size = 1;

  for (int i = 0; i < ndims[1]; i++) {
    size *= shapes[1][i];
  }
  int n = size / THREADS;
  for (int i = 0; i < nparam; i++) {
    if (strcmp(dtypes[i], "float32") != 0) {
      return 2;
    }
  }

  CustomSquareKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(output), size);
  return 0;
}
@@ -0,0 +1,48 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define THREADS 1024
__global__ void CustomSquareBpropKernel(float *input1, float *input3, float *output, size_t size) {
  auto idx = blockIdx.x * THREADS + threadIdx.x;
  if (idx < size) {
    output[idx] = input1[idx] * input3[idx] * 2;
  }
}

extern "C" int CustomSquareBprop(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                                 void *stream, void *extra) {
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  if (nparam != 4) return 1;
  void *input1 = params[0];
  void *input3 = params[2];
  void *output = params[3];

  size_t size = 1;

  for (int i = 0; i < ndims[3]; i++) {
    size *= shapes[3][i];
  }
  int n = size / THREADS;
  for (int i = 0; i < nparam; i++) {
    if (strcmp(dtypes[i], "float32") != 0) {
      return 2;
    }
  }

  CustomSquareBpropKernel<<<n + 1, THREADS, 0, custream>>>(static_cast<float *>(input1), static_cast<float *>(input3),
                                                           static_cast<float *>(output), size);
  return 0;
}
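A note on the expected values in the gradient tests below: SquareGradNet applies the square operator twice, so the network computes $x^4$ and the analytic gradient with unit sensitivity is $4x^3$. For $x = [1, 4, 9]$,

\[
\frac{d}{dx}\bigl((x^{2})^{2}\bigr) = 4x^{3}
\;\Rightarrow\;
4\cdot 1^{3} = 4,\quad 4\cdot 4^{3} = 256,\quad 4\cdot 9^{3} = 2916,
\]

which is the expect array used in test_square_py_bprop and test_square_aot_bprop.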
@@ -0,0 +1,404 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import platform
import numpy as np
import pytest
from mindspore import context, ops, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.op_info_register import DataType
from mindspore.ops.operations.custom_ops import Custom, CustomRegOp


class AOTSingleOutputNet(Cell):
    def __init__(self, func, out_shapes, out_types, reg=None):
        super(AOTSingleOutputNet, self).__init__()

        self.program = Custom(func, out_shapes, out_types, "aot", reg_info=reg)

    def construct(self, x, y):
        return self.program(x, y)


def get_file_path_gpu(cuda, so):
    dir_path = os.path.dirname(os.path.abspath(__file__))
    cmd = "nvcc --shared -Xcompiler -fPIC -o "+dir_path+"/aot_test_files/"+so+" "+dir_path+"/aot_test_files/"+cuda
    func_path = dir_path+"/aot_test_files/"+so
    return cmd, func_path


def get_file_path_cpu(cc, so):
    dir_path = os.path.dirname(os.path.abspath(__file__))
    cmd = "g++ --shared -fPIC -o "+dir_path+"/aot_test_files/"+so+" "+dir_path+"/aot_test_files/"+cc
    func_path = dir_path+"/aot_test_files/"+so
    return cmd, func_path


def check_exec_file(cmd, func_path, source, execf):
    with os.popen(cmd) as f:
        r = f.read()
    if os.path.exists(func_path) and not r:
        pass
    else:
        if os.path.exists(func_path):
            os.remove(func_path)
        assert False, "Failed to compile " + source + " to " + execf


def aot_single_output(get_file_path, source, execf, reg):
    shape = (4, 5)
    input_x = np.random.normal(0, 1, shape).astype(np.float32)
    input_y = np.random.normal(0, 1, shape).astype(np.float32)
    cmd, func_path = get_file_path(source, execf)
    check_exec_file(cmd, func_path, source, execf)
    try:
        test = AOTSingleOutputNet(func_path+":CustomAdd", (shape,), (mstype.float32,), reg)
        output = test(Tensor(input_x), Tensor(input_y))[0]
    except Exception as e:
        if os.path.exists(func_path):
            os.remove(func_path)
        raise e
    os.remove(func_path)
    assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)


add_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .target("GPU") \
    .get_op_info()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_aot_single_output_gpu():
    """
    Feature: custom aot operator, multiple inputs, single output, GPU
    Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    aot_single_output(get_file_path_gpu, "add.cu", "add.so", add_gpu_info)


add_cpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .target("CPU") \
    .get_op_info()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_aot_single_output_cpu():
    """
    Feature: custom aot operator, multiple inputs, single output, CPU
    Description: pre-compile xxx.cc to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    sys = platform.system()
    if sys == 'Windows':
        pass
    else:
        context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
        aot_single_output(get_file_path_cpu, "add.cc", "add.so", add_cpu_info)


reorganize_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .target("GPU") \
    .get_op_info()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reorganize():
    """
    Feature: custom aot operator, multiple inputs(dtypes:float32,int64_t), single output, GPU
    Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    shape = [5]
    input_x = np.array([1.0, 4.0, 9.0, 16.0, 25.0]).astype(np.float32)
    input_y = np.array([3, 2, 0, 1, 4]).astype(np.int64)
    expect = np.array([16.0, 9.0, 1.0, 4.0, 25.0]).astype(np.float32)

    cmd, func_path = get_file_path_gpu("reorganize.cu", "reorganize.so")
    check_exec_file(cmd, func_path, "reorganize.cu", "reorganize.so")
    try:
        test = AOTSingleOutputNet(func_path+":CustomReorganize", (shape,), (mstype.float32,), reorganize_gpu_info)
        output = test(Tensor(input_x), Tensor(input_y))[0]
    except Exception as e:
        if os.path.exists(func_path):
            os.remove(func_path)
        raise e
    os.remove(func_path)
    assert np.allclose(expect, output.asnumpy(), 0.001, 0.001)


hetero_square_mul_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F16_Default, DataType.F16_Default) \
    .target("GPU") \
    .get_op_info()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hetero_square_mul():
    """
    Feature: custom aot operator, multiple inputs(dtypes:float32,float16), single output(dtype:float16), GPU
    Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    shape = [5]
    input_x = np.random.normal(0, 1, shape).astype(np.float32)
    input_y = np.random.normal(0, 1, shape).astype(np.float16)
    expect = (input_x * input_x * input_y.astype(np.float32)).astype(np.float16)
    cmd, func_path = get_file_path_gpu("hetero_square_mul.cu", "hetero_square_mul.so")
    check_exec_file(cmd, func_path, "hetero_square_mul.cu", "hetero_square_mul.so")
    try:
        test = AOTSingleOutputNet(func_path+":CustomHSquareMul", (shape,),
                                  (mstype.float16,), hetero_square_mul_gpu_info)
        output = test(Tensor(input_x), Tensor(input_y))[0]
    except Exception as e:
        if os.path.exists(func_path):
            os.remove(func_path)
        raise e
    os.remove(func_path)
    assert np.allclose(expect, output.asnumpy(), 0.001, 0.001)


class SquareGradNet(Cell):
    def __init__(self, func, out_shapes, out_types, bprop, reg):
        super(SquareGradNet, self).__init__()
        self.square = Custom(func, out_shapes, out_types, "aot", bprop, reg)

    def construct(self, x):
        res = self.square(x)[0]
        res2 = self.square(res)[0]
        return res2


square_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .target("GPU") \
    .get_op_info()


square_bprop_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .input(2, "x3") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .target("GPU") \
    .get_op_info()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_square_py_bprop():
    """
    Feature: custom aot operator, bprop(pyfunc), GPU
    Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
    sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
    expect = np.array([4.0, 256.0, 2916.0]).astype(np.float32)
    cmd, func_path = get_file_path_gpu("square.cu", "square_py.so")
    check_exec_file(cmd, func_path, "square.cu", "square_py.so")

    def bprop(x, out, dout):
        gradient = x * 2
        dx = gradient * dout
        return (dx,)
    try:
        net = SquareGradNet(func_path+":CustomSquare", ([3],), (mstype.float32,), bprop=bprop, reg=square_gpu_info)
        dx = ops.GradOperation(sens_param=True)(net)(Tensor(x), Tensor(sens))
    except Exception as e:
        if os.path.exists(func_path):
            os.remove(func_path)
        raise e
    os.remove(func_path)
    dx_np = dx.asnumpy()
    assert np.allclose(expect, dx_np, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_square_aot_bprop():
    """
    Feature: custom aot operator, bprop(Cell), GPU
    Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
    sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
    expect = np.array([4.0, 256.0, 2916.0]).astype(np.float32)
    cmd_bprop, func_path_bprop = get_file_path_gpu("square_bprop.cu", "square_bprop.so")
    check_exec_file(cmd_bprop, func_path_bprop, "square_bprop.cu", "square_bprop.so")
    try:
        aot_bprop = Custom(func_path_bprop+":CustomSquareBprop",
                           ([3],), (mstype.float32,), "aot", reg_info=square_bprop_gpu_info)
    except Exception as e:
        if os.path.exists(func_path_bprop):
            os.remove(func_path_bprop)
        raise e

    def bprop(x, out, dout):
        res = aot_bprop(x, out, dout)
        return res

    cmd, func_path = get_file_path_gpu("square.cu", "square.so")
    check_exec_file(cmd, func_path, "square.cu", "square.so")
    try:
        net = SquareGradNet(func_path+":CustomSquare", ([3],), (mstype.float32,), bprop=bprop, reg=square_gpu_info)
        dx = ops.GradOperation(sens_param=True)(net)(Tensor(x), Tensor(sens))
    except Exception as e:
        if os.path.exists(func_path):
            os.remove(func_path)
        if os.path.exists(func_path_bprop):
            os.remove(func_path_bprop)
        raise e
    os.remove(func_path)
    os.remove(func_path_bprop)
    dx_np = dx.asnumpy()
    assert np.allclose(expect, dx_np, 0.0001, 0.0001)


class AOTMultiOutputNet(Cell):
    def __init__(self, func, out_shapes, out_types, bprop=None, reg=None):
        super(AOTMultiOutputNet, self).__init__()

        self.program = Custom(func, out_shapes, out_types, "aot", bprop, reg)
        self.add = P.Add()
        self.mul = P.Mul()

    def construct(self, x, y):
        aot = self.program(x, y)
        add_res = self.add(aot[0], aot[1])
        mul_res = self.mul(add_res, aot[2])
        return mul_res


multioutput_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .output(0, "y1") \
    .output(1, "y2") \
    .output(2, "y3") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .target("GPU") \
    .get_op_info()


multioutput_bprop_gpu_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
    .input(0, "x1") \
    .input(1, "x2") \
    .input(2, "x3") \
    .input(3, "x4") \
    .input(4, "x5") \
    .output(0, "y1") \
    .output(1, "y2") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .target("GPU") \
    .get_op_info()


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_add_mul_div_bprop():
    """
    Feature: custom aot operator, bprop(Cell), multiple outputs, GPU
    Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
    Expectation: nn result matches numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
    y = np.array([1.0, 1.0, 1.0]).astype(np.float32)
    sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
    expect_dx = np.array([5.0, 17.0, 37.0]).astype(np.float32)
    expect_dy = np.array([-1.0, -16.0, -81.0]).astype(np.float32)

    cmd_bprop, func_path_bprop = get_file_path_gpu("add_mul_div_bprop.cu", "add_mul_div_bprop.so")
    check_exec_file(cmd_bprop, func_path_bprop, "add_mul_div_bprop.cu", "add_mul_div_bprop.so")
    try:
        aot_bprop = Custom(func_path_bprop+":CustomAddMulDivBprop",
                           ([3], [3]), (mstype.float32, mstype.float32), "aot", reg_info=multioutput_bprop_gpu_info)
    except Exception as e:
        if os.path.exists(func_path_bprop):
            os.remove(func_path_bprop)
        raise e

    def bprop(x, y, out, dout):
        res = aot_bprop(x, y, dout[0], dout[1], dout[2])
        return res

    cmd, func_path = get_file_path_gpu("add_mul_div.cu", "add_mul_div.so")
    check_exec_file(cmd, func_path, "add_mul_div.cu", "add_mul_div.so")
    try:
        net = AOTMultiOutputNet(func_path+":CustomAddMulDiv", ([3], [3], [3]),
                                (mstype.float32, mstype.float32, mstype.float32), bprop=bprop, reg=multioutput_gpu_info)

        dx, dy = ops.GradOperation(sens_param=True, get_all=True)(net)(Tensor(x), Tensor(y), Tensor(sens))
    except Exception as e:
        if os.path.exists(func_path):
            os.remove(func_path)
        if os.path.exists(func_path_bprop):
            os.remove(func_path_bprop)
        raise e
    os.remove(func_path)
    os.remove(func_path_bprop)
    dx_np = dx.asnumpy()
    dy_np = dy.asnumpy()
    assert np.allclose(expect_dx, dx_np, 0.0001, 0.0001)
    assert np.allclose(expect_dy, dy_np, 0.0001, 0.0001)
@@ -136,6 +136,7 @@ def add_n_with_bias(inputs, output, bias, kernel_name="add_n_with_bias"):

 class Net1(Cell):
     """Net definition"""
+
     def __init__(self):
         super(Net1, self).__init__()
         # TBE dsl with attr
@@ -203,9 +204,10 @@ def bprop(data, axis, out, dout):

 class Net2(Cell):
     """Net definition"""
+
     def __init__(self):
         super(Net2, self).__init__()
-        self.square_with_bias = Custom(square_with_bias, out_shape=[3], out_dtype=mstype.float32, grad=bprop,
+        self.square_with_bias = Custom(square_with_bias, out_shape=[3], out_dtype=mstype.float32, bprop=bprop,
                                        func_type="tbe")

     def construct(self, x):
@@ -217,10 +219,10 @@ class Net2(Cell):
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
-def test_grad():
+def test_bprop():
     """
     Feature: ALL To ALL
-    Description: test cases for grad function of Custom op.
+    Description: test cases for bprop function of Custom op.
     Expectation: the result match with numpy result
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")