!26101 【MS】【LITE】infershape support get attr from kernel

Merge pull request !26101 from chenjianping/code_check
This commit is contained in:
i-robot 2021-11-11 04:32:21 +00:00 committed by Gitee
commit 525df1ca30
2 changed files with 16 additions and 1 deletions

View File

@ -25,6 +25,7 @@
namespace mindspore {
namespace kernel {
class Kernel;
/// \brief KernelInterface defined customized op's interface, such as infershape, and so on.
class MS_API KernelInterface {
public:
@ -42,6 +43,19 @@ class MS_API KernelInterface {
const schema::Primitive *primitive) {
return kSuccess;
}
/// \brief Method to infer customized op's output shape.
///
/// \param[in] inputs Define the input tensors of op.
/// \param[in] outputs Define the output tensors of op.
/// \param[in] primitive Define the attributes of op.
/// \param[in] kernel Define the kernel of a certain op.
///
/// \return Status as a status identification of inferring.
virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive, const Kernel *kernel) {
return Infer(inputs, outputs, primitive);
}
};
} // namespace kernel
} // namespace mindspore

View File

@ -70,7 +70,8 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
std::vector<mindspore::MSTensor> out_tensors;
std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_tensors),
[](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared<MSTensor::Impl>(tensor)); });
auto ret = kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive));
auto ret =
kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive), kernel);
if (ret == kLiteInferInvalid) {
return RET_INFER_INVALID;
}