forked from mindspore-Ecosystem/mindspore
!15538 [MS][LITE] add call ops
From: @mengyuanli Reviewed-by: @hangangqiang,@zhang_xue_tong Signed-off-by: @hangangqiang
This commit is contained in:
commit
8c76ee5023
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/call.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameCall, Call);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CALL_H_
|
||||
#define MINDSPORE_CORE_OPS_CALL_H_
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCall = "call";
|
||||
class Call : public PrimitiveC {
|
||||
public:
|
||||
Call() : PrimitiveC(kNameCall) {}
|
||||
~Call() = default;
|
||||
MS_DECLARE_PARENT(Call, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CALL_H_
|
|
@ -207,6 +207,7 @@ union PrimitiveType {
|
|||
ResizeGrad,
|
||||
Splice,
|
||||
LogSoftmax,
|
||||
Call,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -1099,3 +1100,6 @@ table Splice {
|
|||
table LogSoftmax {
|
||||
axis: long;
|
||||
}
|
||||
|
||||
table Call {
|
||||
}
|
||||
|
|
|
@ -206,6 +206,7 @@ OP_TYPE(LayerNormGrad)
|
|||
OP_TYPE(ResizeGrad)
|
||||
OP_TYPE(Splice)
|
||||
OP_TYPE(LogSoftmax)
|
||||
OP_TYPE(Call)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -1098,3 +1099,6 @@ OP_SCHEMA_DEF_END(Splice)
|
|||
OP_SCHEMA_DEF(LogSoftmax)
|
||||
OP_ATTR(axis, long)
|
||||
OP_SCHEMA_DEF_END(LogSoftmax)
|
||||
|
||||
OP_SCHEMA_DEF(Call)
|
||||
OP_SCHEMA_DEF_END(Call)
|
||||
|
|
|
@ -244,6 +244,7 @@
|
|||
#include "ops/grad/abs_grad.h"
|
||||
#include "ops/splice.h"
|
||||
#include "ops/log_softmax.h"
|
||||
#include "ops/call.h"
|
||||
|
||||
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \
|
||||
namespace mindspore::lite::ops { \
|
||||
|
@ -457,5 +458,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad);
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Splice);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Call);
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_
|
||||
|
|
|
@ -755,6 +755,11 @@ schema::PrimitiveT *LogSoftmaxPrimitiveCreator(const AnfNodePtr &node) {
|
|||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
schema::PrimitiveT *CallPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Call>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
|
||||
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
|
||||
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
|
||||
|
@ -969,6 +974,7 @@ RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiv
|
|||
RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
|
||||
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
|
||||
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
|
||||
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue