forked from mindspore-Ecosystem/mindspore
!26101 【MS】【LITE】infershape support get attr from kernel
Merge pull request !26101 from chenjianping/code_check
This commit is contained in:
commit
525df1ca30
|
@ -25,6 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class Kernel;
|
||||
/// \brief KernelInterface defined customized op's interface, such as infershape, and so on.
|
||||
class MS_API KernelInterface {
|
||||
public:
|
||||
|
@ -42,6 +43,19 @@ class MS_API KernelInterface {
|
|||
const schema::Primitive *primitive) {
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
/// \brief Method to infer customized op's output shape.
|
||||
///
|
||||
/// \param[in] inputs Define the input tensors of op.
|
||||
/// \param[in] outputs Define the output tensors of op.
|
||||
/// \param[in] primitive Define the attributes of op.
|
||||
/// \param[in] kernel Define the kernel of a certain op.
|
||||
///
|
||||
/// \return Status as a status identification of inferring.
|
||||
virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive, const Kernel *kernel) {
|
||||
return Infer(inputs, outputs, primitive);
|
||||
}
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -70,7 +70,8 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
|
|||
std::vector<mindspore::MSTensor> out_tensors;
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_tensors),
|
||||
[](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared<MSTensor::Impl>(tensor)); });
|
||||
auto ret = kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive));
|
||||
auto ret =
|
||||
kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive), kernel);
|
||||
if (ret == kLiteInferInvalid) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue