forked from mindspore-Ecosystem/mindspore
move GetCreator to src
This commit is contained in:
parent
e1a9dccac6
commit
3342022c3f
|
@ -27,12 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
/// \brief CapabilityParam defined performance of op when running.
|
||||
struct MS_API CapabilityParam {
|
||||
float exec_time_; /**< op running time argument */
|
||||
float power_usage_; /**< op power waste argument */
|
||||
};
|
||||
|
||||
/// \brief KernelInterface defined customized op's interface, such as infershape, and so on.
|
||||
class MS_API KernelInterface {
|
||||
public:
|
||||
|
@ -50,18 +44,6 @@ class MS_API KernelInterface {
|
|||
const schema::Primitive *primitive) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// \brief Method to get performance of an op when running.
|
||||
///
|
||||
/// \param[in] tensor_in Define the input tensors of op.
|
||||
/// \param[in] primitive Define the attributes of op.
|
||||
/// \param[in] param Define the contr of performance.
|
||||
///
|
||||
/// \return STATUS as an error code of inferring, STATUS is defined in errorcode.h.
|
||||
virtual int GetCapability(const std::vector<mindspore::MSTensor> &tensor_in, const schema::Primitive *primitive,
|
||||
CapabilityParam *param) {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
|
||||
|
|
|
@ -29,26 +29,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
/// \brief KernelDesc defined kernel's basic attribute.
|
||||
struct MS_API KernelDesc {
|
||||
TypeId data_type; /**< kernel data type argument */
|
||||
int type; /**< op type argument */
|
||||
std::string arch; /**< deviceType argument */
|
||||
std::string provider; /**< user identification argument */
|
||||
|
||||
bool operator<(const KernelDesc &dst) const {
|
||||
if (provider != dst.provider) {
|
||||
return provider < dst.provider;
|
||||
} else if (arch != dst.arch) {
|
||||
return arch < dst.arch;
|
||||
} else if (data_type != dst.data_type) {
|
||||
return data_type < dst.data_type;
|
||||
} else {
|
||||
return type < dst.type;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief CreateKernel Defined a functor to create a kernel.
|
||||
///
|
||||
/// \param[in] inputs Define input tensors of kernel.
|
||||
|
@ -87,14 +67,6 @@ class MS_API RegisterKernel {
|
|||
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
|
||||
static int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
|
||||
const std::string &type, CreateKernel creator);
|
||||
|
||||
/// \brief Static methon to get a kernel's create function.
|
||||
///
|
||||
/// \param[in] desc Define kernel's basic attribute.
|
||||
/// \param[in] primitive Define the attributes of op.
|
||||
///
|
||||
/// \return Function pointer to create a kernel.
|
||||
static CreateKernel GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc);
|
||||
};
|
||||
|
||||
/// \brief KernelReg Defined registration class of kernel.
|
||||
|
|
|
@ -139,6 +139,7 @@ set(LITE_SRC
|
|||
${LITE_DIR}/src/registry/kernel_interface.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/registry/register_kernel.cc
|
||||
${LITE_DIR}/src/registry/register_utils.cc
|
||||
${LITE_DIR}/src/registry/register_kernel_impl.cc
|
||||
${LITE_DIR}/src/lite_model.cc
|
||||
${LITE_DIR}/src/ms_tensor.cc
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/registry/register_kernel.h"
|
||||
#include "src/registry/register_utils.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
|
@ -138,7 +139,7 @@ int KernelRegistry::GetCustomKernel(const std::vector<Tensor *> &in_tensors, con
|
|||
MS_ASSERT(kernel != nullptr);
|
||||
kernel::KernelDesc desc;
|
||||
KernelKeyToKernelDesc(key, &desc);
|
||||
CreateKernel creator = kernel::RegisterKernel::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc);
|
||||
CreateKernel creator = kernel::RegisterUtils::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc);
|
||||
if (creator == nullptr) {
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
|
|
@ -29,9 +29,5 @@ int RegisterKernel::RegKernel(const std::string &arch, const std::string &provid
|
|||
CreateKernel creator) {
|
||||
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
|
||||
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc) {
|
||||
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include "include/registry/register_kernel.h"
|
||||
#include "src/registry/register_utils.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_MAX;
|
||||
using mindspore::schema::PrimitiveType_MIN;
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/registry/register_utils.h"
|
||||
#include "src/registry/register_kernel_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
CreateKernel RegisterUtils::GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc) {
|
||||
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
|
||||
#include <string>
|
||||
#include "include/registry/register_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
/// \brief KernelDesc defined kernel's basic attribute.
|
||||
struct KernelDesc {
|
||||
TypeId data_type; /**< kernel data type argument */
|
||||
int type; /**< op type argument */
|
||||
std::string arch; /**< deviceType argument */
|
||||
std::string provider; /**< user identification argument */
|
||||
|
||||
bool operator<(const KernelDesc &dst) const {
|
||||
if (provider != dst.provider) {
|
||||
return provider < dst.provider;
|
||||
} else if (arch != dst.arch) {
|
||||
return arch < dst.arch;
|
||||
} else if (data_type != dst.data_type) {
|
||||
return data_type < dst.data_type;
|
||||
} else {
|
||||
return type < dst.type;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief RegisterKernel Defined registration of kernel.
|
||||
class RegisterUtils {
|
||||
public:
|
||||
/// \brief Static methon to get a kernel's create function.
|
||||
///
|
||||
/// \param[in] desc Define kernel's basic attribute.
|
||||
/// \param[in] primitive Define the attributes of op.
|
||||
///
|
||||
/// \return Function pointer to create a kernel.
|
||||
static CreateKernel GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
|
|
@ -141,6 +141,8 @@ set(LITE_SRC
|
|||
${SRC_DIR}/ms_tensor.cc
|
||||
${SRC_DIR}/tensorlist.cc
|
||||
${SRC_DIR}/registry/kernel_interface_registry.cc
|
||||
${SRC_DIR}/registry/register_utils.cc
|
||||
${SRC_DIR}/registry/register_kernel_impl.cc
|
||||
${SRC_DIR}/registry/kernel_interface.cc
|
||||
${SRC_DIR}/kernel_registry.cc
|
||||
${SRC_DIR}/inner_kernel.cc
|
||||
|
|
Loading…
Reference in New Issue