move GetCreator to src

This commit is contained in:
chenjianping 2021-08-04 10:29:57 +08:00
parent e1a9dccac6
commit 3342022c3f
9 changed files with 90 additions and 51 deletions

View File

@ -27,12 +27,6 @@
namespace mindspore {
namespace kernel {
/// \brief CapabilityParam defined performance of op when running.
struct MS_API CapabilityParam {
float exec_time_; /**< op running time argument */
float power_usage_; /**< op power waste argument */
};
/// \brief KernelInterface defined customized op's interface, such as infershape, and so on.
class MS_API KernelInterface {
public:
@ -50,18 +44,6 @@ class MS_API KernelInterface {
const schema::Primitive *primitive) {
return 0;
}
/// \brief Method to get performance of an op when running.
///
/// \param[in] tensor_in Define the input tensors of op.
/// \param[in] primitive Define the attributes of op.
/// \param[in] param Define the contr of performance.
///
/// \return STATUS as an error code of inferring, STATUS is defined in errorcode.h.
virtual int GetCapability(const std::vector<mindspore::MSTensor> &tensor_in, const schema::Primitive *primitive,
CapabilityParam *param) {
return 0;
}
};
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.

View File

@ -29,26 +29,6 @@
namespace mindspore {
namespace kernel {
/// \brief KernelDesc defined kernel's basic attribute.
struct MS_API KernelDesc {
TypeId data_type; /**< kernel data type argument */
int type; /**< op type argument */
std::string arch; /**< deviceType argument */
std::string provider; /**< user identification argument */
bool operator<(const KernelDesc &dst) const {
if (provider != dst.provider) {
return provider < dst.provider;
} else if (arch != dst.arch) {
return arch < dst.arch;
} else if (data_type != dst.data_type) {
return data_type < dst.data_type;
} else {
return type < dst.type;
}
}
};
/// \brief CreateKernel Defined a functor to create a kernel.
///
/// \param[in] inputs Define input tensors of kernel.
@ -87,14 +67,6 @@ class MS_API RegisterKernel {
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
static int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
const std::string &type, CreateKernel creator);
/// \brief Static methon to get a kernel's create function.
///
/// \param[in] desc Define kernel's basic attribute.
/// \param[in] primitive Define the attributes of op.
///
/// \return Function pointer to create a kernel.
static CreateKernel GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc);
};
/// \brief KernelReg Defined registration class of kernel.

View File

@ -139,6 +139,7 @@ set(LITE_SRC
${LITE_DIR}/src/registry/kernel_interface.cc
${LITE_DIR}/src/registry/kernel_interface_registry.cc
${LITE_DIR}/src/registry/register_kernel.cc
${LITE_DIR}/src/registry/register_utils.cc
${LITE_DIR}/src/registry/register_kernel_impl.cc
${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/ms_tensor.cc

View File

@ -18,6 +18,7 @@
#include <memory>
#include "include/errorcode.h"
#include "include/registry/register_kernel.h"
#include "src/registry/register_utils.h"
#include "src/ops/populate/populate_register.h"
#include "src/common/version_manager.h"
#include "nnacl/pooling_parameter.h"
@ -138,7 +139,7 @@ int KernelRegistry::GetCustomKernel(const std::vector<Tensor *> &in_tensors, con
MS_ASSERT(kernel != nullptr);
kernel::KernelDesc desc;
KernelKeyToKernelDesc(key, &desc);
CreateKernel creator = kernel::RegisterKernel::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc);
CreateKernel creator = kernel::RegisterUtils::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc);
if (creator == nullptr) {
return RET_NOT_SUPPORT;
}

View File

@ -29,9 +29,5 @@ int RegisterKernel::RegKernel(const std::string &arch, const std::string &provid
CreateKernel creator) {
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
}
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc) {
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
}
} // namespace kernel
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include <vector>
#include <set>
#include "include/registry/register_kernel.h"
#include "src/registry/register_utils.h"
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;

View File

@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/registry/register_utils.h"
#include "src/registry/register_kernel_impl.h"
namespace mindspore {
namespace kernel {
CreateKernel RegisterUtils::GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc) {
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
#define MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_
#include <string>
#include "include/registry/register_kernel.h"
#include "schema/model_generated.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace kernel {
/// \brief KernelDesc defined kernel's basic attribute.
struct KernelDesc {
TypeId data_type; /**< kernel data type argument */
int type; /**< op type argument */
std::string arch; /**< deviceType argument */
std::string provider; /**< user identification argument */
bool operator<(const KernelDesc &dst) const {
if (provider != dst.provider) {
return provider < dst.provider;
} else if (arch != dst.arch) {
return arch < dst.arch;
} else if (data_type != dst.data_type) {
return data_type < dst.data_type;
} else {
return type < dst.type;
}
}
};
/// \brief RegisterKernel Defined registration of kernel.
class RegisterUtils {
public:
/// \brief Static methon to get a kernel's create function.
///
/// \param[in] desc Define kernel's basic attribute.
/// \param[in] primitive Define the attributes of op.
///
/// \return Function pointer to create a kernel.
static CreateKernel GetCreator(const schema::Primitive *primitive, kernel::KernelDesc *desc);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_UTILS_H_

View File

@ -141,6 +141,8 @@ set(LITE_SRC
${SRC_DIR}/ms_tensor.cc
${SRC_DIR}/tensorlist.cc
${SRC_DIR}/registry/kernel_interface_registry.cc
${SRC_DIR}/registry/register_utils.cc
${SRC_DIR}/registry/register_kernel_impl.cc
${SRC_DIR}/registry/kernel_interface.cc
${SRC_DIR}/kernel_registry.cc
${SRC_DIR}/inner_kernel.cc