move GetCreator to src

This commit is contained in:
chenjianping 2021-08-04 10:29:57 +08:00
parent e1a9dccac6
commit 3342022c3f
9 changed files with 90 additions and 51 deletions

View File

@ -27,12 +27,6 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
/// \brief CapabilityParam defined performance of op when running.
struct MS_API CapabilityParam {
float exec_time_; /**< op running time argument */
float power_usage_; /**< op power waste argument */
};
/// \brief KernelInterface defined customized op's interface, such as infershape, and so on. /// \brief KernelInterface defined customized op's interface, such as infershape, and so on.
class MS_API KernelInterface { class MS_API KernelInterface {
public: public:
@ -50,18 +44,6 @@ class MS_API KernelInterface {
const schema::Primitive *primitive) { const schema::Primitive *primitive) {
return 0; return 0;
} }
/// \brief Method to get performance of an op when running.
///
/// \param[in] tensor_in Define the input tensors of op.
/// \param[in] primitive Define the attributes of op.
/// \param[in] param Define the contr of performance.
///
/// \return STATUS as an error code of inferring, STATUS is defined in errorcode.h.
virtual int GetCapability(const std::vector<mindspore::MSTensor> &tensor_in, const schema::Primitive *primitive,
CapabilityParam *param) {
return 0;
}
}; };
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface. /// \brief KernelInterfaceCreator defined a functor to create KernelInterface.

View File

@ -29,26 +29,6 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
/// \brief KernelDesc defined kernel's basic attribute.
struct MS_API KernelDesc {
TypeId data_type; /**< kernel data type argument */
int type; /**< op type argument */
std::string arch; /**< deviceType argument */
std::string provider; /**< user identification argument */
bool operator<(const KernelDesc &dst) const {
if (provider != dst.provider) {
return provider < dst.provider;
} else if (arch != dst.arch) {
return arch < dst.arch;
} else if (data_type != dst.data_type) {
return data_type < dst.data_type;
} else {
return type < dst.type;
}
}
};
/// \brief CreateKernel Defined a functor to create a kernel. /// \brief CreateKernel Defined a functor to create a kernel.
/// ///
/// \param[in] inputs Define input tensors of kernel. /// \param[in] inputs Define input tensors of kernel.
@ -87,14 +67,6 @@ class MS_API RegisterKernel {
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h. /// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
static int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, static int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
const std::string &type, CreateKernel creator); const std::string &type, CreateKernel creator);
/// \brief Static methon to get a kernel's create function.
///
/// \param[in] desc Define kernel's basic attribute.
/// \param[in] primitive Define the attributes of op.
///
/// \return Function pointer to create a kernel.
static CreateKernel GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc);
}; };
/// \brief KernelReg Defined registration class of kernel. /// \brief KernelReg Defined registration class of kernel.

View File

@ -139,6 +139,7 @@ set(LITE_SRC
${LITE_DIR}/src/registry/kernel_interface.cc ${LITE_DIR}/src/registry/kernel_interface.cc
${LITE_DIR}/src/registry/kernel_interface_registry.cc ${LITE_DIR}/src/registry/kernel_interface_registry.cc
${LITE_DIR}/src/registry/register_kernel.cc ${LITE_DIR}/src/registry/register_kernel.cc
${LITE_DIR}/src/registry/register_utils.cc
${LITE_DIR}/src/registry/register_kernel_impl.cc ${LITE_DIR}/src/registry/register_kernel_impl.cc
${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/ms_tensor.cc ${LITE_DIR}/src/ms_tensor.cc

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "include/registry/register_kernel.h" #include "include/registry/register_kernel.h"
#include "src/registry/register_utils.h"
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "src/common/version_manager.h" #include "src/common/version_manager.h"
#include "nnacl/pooling_parameter.h" #include "nnacl/pooling_parameter.h"
@ -138,7 +139,7 @@ int KernelRegistry::GetCustomKernel(const std::vector<Tensor *> &in_tensors, con
MS_ASSERT(kernel != nullptr); MS_ASSERT(kernel != nullptr);
kernel::KernelDesc desc; kernel::KernelDesc desc;
KernelKeyToKernelDesc(key, &desc); KernelKeyToKernelDesc(key, &desc);
CreateKernel creator = kernel::RegisterKernel::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc); CreateKernel creator = kernel::RegisterUtils::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc);
if (creator == nullptr) { if (creator == nullptr) {
return RET_NOT_SUPPORT; return RET_NOT_SUPPORT;
} }

View File

@ -29,9 +29,5 @@ int RegisterKernel::RegKernel(const std::string &arch, const std::string &provid
CreateKernel creator) { CreateKernel creator) {
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
} }
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc) {
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -24,6 +24,7 @@
#include <vector> #include <vector>
#include <set> #include <set>
#include "include/registry/register_kernel.h" #include "include/registry/register_kernel.h"
#include "src/registry/register_utils.h"
using mindspore::schema::PrimitiveType_MAX; using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN; using mindspore::schema::PrimitiveType_MIN;

View File

@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/registry/register_utils.h"
#include "src/registry/register_kernel_impl.h"
namespace mindspore {
namespace kernel {
CreateKernel RegisterUtils::GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc) {
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
#define MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
#include <string>
#include "include/registry/register_kernel.h"
#include "schema/model_generated.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace kernel {
/// \brief KernelDesc defined kernel's basic attribute.
struct KernelDesc {
TypeId data_type; /**< kernel data type argument */
int type; /**< op type argument */
std::string arch; /**< deviceType argument */
std::string provider; /**< user identification argument */
bool operator<(const KernelDesc &dst) const {
if (provider != dst.provider) {
return provider < dst.provider;
} else if (arch != dst.arch) {
return arch < dst.arch;
} else if (data_type != dst.data_type) {
return data_type < dst.data_type;
} else {
return type < dst.type;
}
}
};
/// \brief RegisterKernel Defined registration of kernel.
class RegisterUtils {
public:
/// \brief Static methon to get a kernel's create function.
///
/// \param[in] desc Define kernel's basic attribute.
/// \param[in] primitive Define the attributes of op.
///
/// \return Function pointer to create a kernel.
static CreateKernel GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_

View File

@ -141,6 +141,8 @@ set(LITE_SRC
${SRC_DIR}/ms_tensor.cc ${SRC_DIR}/ms_tensor.cc
${SRC_DIR}/tensorlist.cc ${SRC_DIR}/tensorlist.cc
${SRC_DIR}/registry/kernel_interface_registry.cc ${SRC_DIR}/registry/kernel_interface_registry.cc
${SRC_DIR}/registry/register_utils.cc
${SRC_DIR}/registry/register_kernel_impl.cc
${SRC_DIR}/registry/kernel_interface.cc ${SRC_DIR}/registry/kernel_interface.cc
${SRC_DIR}/kernel_registry.cc ${SRC_DIR}/kernel_registry.cc
${SRC_DIR}/inner_kernel.cc ${SRC_DIR}/inner_kernel.cc