forked from mindspore-Ecosystem/mindspore
!31666 [MS][LITE] new core ops api and lite adapter new api
Merge pull request !31666 from luoyuan/core2
This commit is contained in:
commit
0b6f330d7e
|
@ -24,6 +24,7 @@
|
|||
#include "utils/shape_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define USE_DEPRECATED_API
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
|
@ -69,8 +69,8 @@ class RegisterStandardPrimitiveEvalHelper {
|
|||
static auto helper_##name = \
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \
|
||||
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
|
||||
auto out = std::make_shared<name>(); \
|
||||
return out; \
|
||||
name out; \
|
||||
return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl()); \
|
||||
} \
|
||||
ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
|
||||
} // namespace abstract
|
||||
|
|
|
@ -114,5 +114,10 @@ enum PaddingMode : int64_t {
|
|||
SYMMETRIC = 2,
|
||||
MODE_RESERVED = 3,
|
||||
};
|
||||
|
||||
enum PoolMode : int64_t {
|
||||
MAX_POOLING = 0,
|
||||
MEAN_POOLING = 1,
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_
|
||||
|
|
|
@ -153,5 +153,7 @@ class MIND_API AbstractTuple : public AbstractSequence {
|
|||
/// \param[in] elements A list of abstracts.
|
||||
explicit AbstractTuple(const AbstractBasePtrList &elements);
|
||||
};
|
||||
|
||||
using AbstractTuplePtr = SharedPtr<AbstractTuple>;
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
||||
|
|
|
@ -47,5 +47,9 @@ using FuncGraphPtr = SharedPtr<FuncGraph>;
|
|||
|
||||
class FuncGraphManager;
|
||||
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
|
||||
|
||||
class CNode;
|
||||
using CNodePtr = SharedPtr<CNode>;
|
||||
using CNodePtrList = std::vector<CNodePtr>;
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -49,6 +50,7 @@ TypePtr LayerNormBetaGammaBackpropInferType(const PrimitivePtr &prim, const std:
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(LayerNormBetaGammaBackprop, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -20,23 +20,22 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class MS_CORE_API LayerNormBetaGammaBackprop : public PrimitiveC {
|
||||
class MIND_API LayerNormBetaGammaBackprop : public BaseOperator {
|
||||
public:
|
||||
LayerNormBetaGammaBackprop() : PrimitiveC(prim::kPrimLayerNormBetaGammaBackprop->name()) {}
|
||||
~LayerNormBetaGammaBackprop() = default;
|
||||
MS_DECLARE_PARENT(LayerNormBetaGammaBackprop, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(LayerNormBetaGammaBackprop);
|
||||
LayerNormBetaGammaBackprop() : BaseOperator("LayerNormBetaGammaBackprop") {}
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -44,6 +45,7 @@ TypePtr LayerNormXBackpropInferType(const PrimitivePtr &prim, const std::vector<
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(LayerNormXBackprop, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -20,23 +20,21 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class MS_CORE_API LayerNormXBackprop : public PrimitiveC {
|
||||
class MIND_API LayerNormXBackprop : public BaseOperator {
|
||||
public:
|
||||
LayerNormXBackprop() : PrimitiveC(prim::kPrimLayerNormXBackprop->name()) {}
|
||||
~LayerNormXBackprop() = default;
|
||||
MS_DECLARE_PARENT(LayerNormXBackprop, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(LayerNormXBackprop);
|
||||
LayerNormXBackprop() : BaseOperator("LayerNormXBackprop") {}
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -142,6 +143,8 @@ ValuePtr AbsInferValue(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
|||
return result_tensor;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Abs, PrimitiveC, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer, AbsInferValue, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,21 +19,18 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Returns absolute value of a tensor element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.Abs for more details.
|
||||
class MS_CORE_API Abs : public PrimitiveC {
|
||||
class MIND_API Abs : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Abs);
|
||||
/// \brief Constructor.
|
||||
Abs() : PrimitiveC(prim::kPrimAbs->name()) { InitIOName({"input_x"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~Abs() = default;
|
||||
MS_DECLARE_PARENT(Abs, PrimitiveC);
|
||||
Abs() : BaseOperator("Abs") { InitIOName({"input_x"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Abs for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
|
||||
#include "ops/accumulate_n_v2.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -83,6 +85,7 @@ TypePtr AccumulateNV2InferType(const PrimitivePtr &prim, const std::vector<Abstr
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(AccumulateNV2, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,21 +19,19 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAccumulateNV2 = "AccumulateNV2";
|
||||
class MS_CORE_API AccumulateNV2 : public PrimitiveC {
|
||||
class MIND_API AccumulateNV2 : public BaseOperator {
|
||||
public:
|
||||
AccumulateNV2() : PrimitiveC(kNameAccumulateNV2) { InitIOName({"inputs"}, {"sum"}); }
|
||||
~AccumulateNV2() = default;
|
||||
MS_DECLARE_PARENT(AccumulateNV2, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(AccumulateNV2);
|
||||
AccumulateNV2() : BaseOperator(kNameAccumulateNV2) { InitIOName({"inputs"}, {"sum"}); }
|
||||
};
|
||||
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimAccumulateNV2Ptr = std::shared_ptr<AccumulateNV2>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,15 @@
|
|||
*/
|
||||
|
||||
#include "ops/acos.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -38,6 +47,7 @@ TypePtr ACosInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ACos, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,28 +22,23 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameACos = "ACos";
|
||||
/// \brief Computes arccosine of input tensors element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.ACos for more details.
|
||||
class ACos : public PrimitiveC {
|
||||
class MIND_API ACos : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ACos);
|
||||
/// \brief Constructor.
|
||||
ACos() : PrimitiveC(kNameACos) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Destructor.
|
||||
~ACos() = default;
|
||||
|
||||
MS_DECLARE_PARENT(ACos, PrimitiveC);
|
||||
ACos() : BaseOperator(kNameACos) { InitIOName({"x"}, {"y"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimACosPtr = std::shared_ptr<ACos>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -15,6 +15,10 @@
|
|||
*/
|
||||
|
||||
#include "ops/acosh.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -42,6 +46,7 @@ TypePtr AcoshInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Acosh, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,28 +22,23 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAcosh = "Acosh";
|
||||
/// \brief Computes arccosh of input tensors element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.Acosh for more details.
|
||||
class Acosh : public PrimitiveC {
|
||||
class MIND_API Acosh : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Acosh);
|
||||
/// \brief Constructor.
|
||||
Acosh() : PrimitiveC(kNameAcosh) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Destructor.
|
||||
~Acosh() = default;
|
||||
|
||||
MS_DECLARE_PARENT(Acosh, PrimitiveC);
|
||||
Acosh() : BaseOperator(kNameAcosh) { InitIOName({"x"}, {"y"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimAcoshPtr = std::shared_ptr<Acosh>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -79,14 +80,18 @@ abstract::TupleShapePtr AdamInferShape(const PrimitivePtr &primitive, const std:
|
|||
std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Adam, PrimitiveC, BaseOperator);
|
||||
void Adam::Init(const bool use_locking, const bool use_nesterov) {
|
||||
this->set_use_locking(use_locking);
|
||||
this->set_use_nesterov(use_nesterov);
|
||||
}
|
||||
|
||||
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, api::MakeValue(use_locking)); }
|
||||
|
||||
void Adam::set_use_nesterov(const bool use_nesterov) { (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
void Adam::set_use_nesterov(const bool use_nesterov) {
|
||||
(void)this->AddAttr(kUseNesterov, api::MakeValue(use_nesterov));
|
||||
}
|
||||
|
||||
bool Adam::get_use_locking() const {
|
||||
auto value_ptr = GetAttr(kUseLocking);
|
||||
|
|
|
@ -20,22 +20,19 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdam = "Adam";
|
||||
/// \brief Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
|
||||
/// Refer to Python API @ref mindspore.ops.Adam for more details.
|
||||
class MS_CORE_API Adam : public PrimitiveC {
|
||||
class MIND_API Adam : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Adam);
|
||||
/// \brief Constructor.
|
||||
Adam() : PrimitiveC(kNameAdam) {}
|
||||
/// \brief Destructor.
|
||||
~Adam() = default;
|
||||
MS_DECLARE_PARENT(Adam, PrimitiveC);
|
||||
Adam() : BaseOperator(kNameAdam) {}
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Adam for the inputs.
|
||||
void Init(const bool use_locking = false, const bool use_nesterov = false);
|
||||
/// \brief Set use_locking.
|
||||
|
@ -51,8 +48,8 @@ class MS_CORE_API Adam : public PrimitiveC {
|
|||
/// \return use_nesterov.
|
||||
bool get_use_nesterov() const;
|
||||
};
|
||||
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimAdamPtr = std::shared_ptr<Adam>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,9 +21,11 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(Add, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -20,28 +20,25 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdd = prim::kAdd;
|
||||
constexpr auto kNameAdd = "Add";
|
||||
/// \brief Adds two input tensors element-wise. Refer to Python API @ref mindspore.ops.Add for more details.
|
||||
class MS_CORE_API Add : public PrimitiveC {
|
||||
class MIND_API Add : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Add);
|
||||
/// \brief Constructor.
|
||||
Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
|
||||
explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~Add() = default;
|
||||
MS_DECLARE_PARENT(Add, PrimitiveC);
|
||||
Add() : BaseOperator(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
|
||||
explicit Add(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "y"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Add for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -73,6 +74,8 @@ TypePtr AddcdivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
return input_data_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Addcdiv, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,23 +19,20 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAddcdiv = "Addcdiv";
|
||||
class Addcdiv : public PrimitiveC {
|
||||
class MIND_API Addcdiv : public BaseOperator {
|
||||
public:
|
||||
Addcdiv() : PrimitiveC(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
||||
~Addcdiv() = default;
|
||||
MS_DECLARE_PARENT(Addcdiv, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(Addcdiv);
|
||||
Addcdiv() : BaseOperator(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimAddcdivPtr = std::shared_ptr<Addcdiv>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -73,6 +74,8 @@ TypePtr AddcmulInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
return input_data_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Addcmul, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,23 +19,20 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAddcmul = "Addcmul";
|
||||
class Addcmul : public PrimitiveC {
|
||||
class MIND_API Addcmul : public BaseOperator {
|
||||
public:
|
||||
Addcmul() : PrimitiveC(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
||||
~Addcmul() = default;
|
||||
MS_DECLARE_PARENT(Addcmul, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(Addcmul);
|
||||
Addcmul() : BaseOperator(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimAddcmulPtr = std::shared_ptr<Addcmul>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
|
||||
#include "ops/adder.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(Adder, PrimitiveC, BaseOperator);
|
||||
void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
|
||||
const std::vector<int64_t> &dilation, const int64_t group, const Format &format) {
|
||||
|
@ -33,14 +35,16 @@ void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std:
|
|||
set_format(format);
|
||||
}
|
||||
|
||||
void Adder::set_in_channel(const int64_t in_channel) { (void)this->AddAttr(kInChannel, MakeValue(in_channel)); }
|
||||
void Adder::set_in_channel(const int64_t in_channel) { (void)this->AddAttr(kInChannel, api::MakeValue(in_channel)); }
|
||||
|
||||
int64_t Adder::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_out_channel(const int64_t out_channel) { (void)this->AddAttr(kOutChannel, MakeValue(out_channel)); }
|
||||
void Adder::set_out_channel(const int64_t out_channel) {
|
||||
(void)this->AddAttr(kOutChannel, api::MakeValue(out_channel));
|
||||
}
|
||||
|
||||
int64_t Adder::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
|
@ -48,7 +52,7 @@ int64_t Adder::get_out_channel() const {
|
|||
}
|
||||
|
||||
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
(void)this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
(void)this->AddAttr(kKernelSize, api::MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Adder::get_kernel_size() const {
|
||||
|
@ -58,7 +62,7 @@ std::vector<int64_t> Adder::get_kernel_size() const {
|
|||
|
||||
void Adder::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, api::MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode Adder::get_pad_mode() const {
|
||||
|
@ -66,28 +70,32 @@ PadMode Adder::get_pad_mode() const {
|
|||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
void Adder::set_stride(const std::vector<int64_t> &stride) { (void)this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void Adder::set_stride(const std::vector<int64_t> &stride) { (void)this->AddAttr(kStride, api::MakeValue(stride)); }
|
||||
|
||||
std::vector<int64_t> Adder::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { (void)this->AddAttr(kPadList, MakeValue(pad_list)); }
|
||||
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
(void)this->AddAttr(kPadList, api::MakeValue(pad_list));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Adder::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_dilation(const std::vector<int64_t> &dilation) { (void)this->AddAttr(kDilation, MakeValue(dilation)); }
|
||||
void Adder::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
(void)this->AddAttr(kDilation, api::MakeValue(dilation));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Adder::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_group(const int64_t group) { (void)this->AddAttr(kGroup, MakeValue(group)); }
|
||||
void Adder::set_group(const int64_t group) { (void)this->AddAttr(kGroup, api::MakeValue(group)); }
|
||||
|
||||
int64_t Adder::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
|
@ -96,7 +104,7 @@ int64_t Adder::get_group() const {
|
|||
|
||||
void Adder::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
||||
(void)this->AddAttr(kFormat, api::MakeValue(swi));
|
||||
}
|
||||
|
||||
Format Adder::get_format() const {
|
||||
|
|
|
@ -21,22 +21,19 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/base/format.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdder = "Adder";
|
||||
/// \brief All defined All operator prototype of lite.
|
||||
class MS_CORE_API Adder : public PrimitiveC {
|
||||
class MIND_API Adder : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Adder);
|
||||
/// \brief Constructor.
|
||||
explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {}
|
||||
|
||||
/// \brief Destructor.
|
||||
~Adder() = default;
|
||||
MS_DECLARE_PARENT(Adder, PrimitiveC);
|
||||
explicit Adder(const std::string &k_name = kNameAdder) : BaseOperator(k_name) {}
|
||||
|
||||
/// \brief Method to init the op's attributes.
|
||||
///
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include <memory>
|
||||
#include "ops/addn.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -83,6 +85,8 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
|
|||
return elements[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(AddN, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -18,27 +18,24 @@
|
|||
#define MINDSPORE_CORE_OPS_ADDN_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAddN = "AddN";
|
||||
/// \brief Computes addition of all input tensors element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.AddN for more details.
|
||||
class MS_CORE_API AddN : public PrimitiveC {
|
||||
class MIND_API AddN : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AddN);
|
||||
/// \brief Constructor.
|
||||
AddN() : PrimitiveC(kNameAddN) { InitIOName({"inputs"}, {"sum"}); }
|
||||
/// \brief Destructor.
|
||||
~AddN() = default;
|
||||
MS_DECLARE_PARENT(AddN, PrimitiveC);
|
||||
AddN() : BaseOperator(kNameAddN) { InitIOName({"inputs"}, {"sum"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AddN for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,8 +17,11 @@
|
|||
#include "ops/affine.h"
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(Affine, PrimitiveC, BaseOperator);
|
||||
void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a, bool transpose_b) {
|
||||
this->set_context(contexts);
|
||||
this->set_output_dim(output_dim);
|
||||
|
@ -27,17 +30,17 @@ void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool
|
|||
}
|
||||
|
||||
void Affine::set_context(const std::vector<int64_t> &context) {
|
||||
(void)this->AddAttr(kAffineContext, MakeValue(context));
|
||||
(void)this->AddAttr(kAffineContext, api::MakeValue(context));
|
||||
}
|
||||
|
||||
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, MakeValue(output_dim)); }
|
||||
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, api::MakeValue(output_dim)); }
|
||||
|
||||
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, api::MakeValue(transpose_a)); }
|
||||
|
||||
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, api::MakeValue(transpose_b)); }
|
||||
|
||||
void Affine::set_activation_type(const ActivationType &activation_type) {
|
||||
(void)this->AddAttr(kActivationType, MakeValue(static_cast<int64_t>(activation_type)));
|
||||
(void)this->AddAttr(kActivationType, api::MakeValue(static_cast<int64_t>(activation_type)));
|
||||
}
|
||||
|
||||
bool Affine::get_transpose_a() const {
|
||||
|
|
|
@ -18,25 +18,21 @@
|
|||
#define MINDSPORE_CORE_OPS_AFFINE_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
constexpr auto kNameAffine = "Affine";
|
||||
constexpr auto kAffineContext = "context";
|
||||
constexpr auto kAffineOutputDim = "output_dim";
|
||||
|
||||
/// \brief Assert defined Affine operator prototype of lite.
|
||||
class MS_CORE_API Affine : public PrimitiveC {
|
||||
class MIND_API Affine : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Affine);
|
||||
/// \brief Constructor.
|
||||
Affine() : PrimitiveC(kNameAffine) { InitIOName({"x1", "x2"}, {"outputs"}); }
|
||||
/// \brief Destructor.
|
||||
~Affine() = default;
|
||||
MS_DECLARE_PARENT(Affine, PrimitiveC);
|
||||
Affine() : BaseOperator(kNameAffine) { InitIOName({"x1", "x2"}, {"outputs"}); }
|
||||
/// \brief Method to init the op's attributes.
|
||||
void Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a = false,
|
||||
bool transpose_b = false);
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
#include "ops/all.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(All, PrimitiveC, BaseOperator);
|
||||
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
|
||||
|
||||
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
|
||||
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
|
||||
|
||||
int64_t All::get_keep_dims() const {
|
||||
auto value_ptr = GetAttr(kKeepDims);
|
||||
|
|
|
@ -16,23 +16,18 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_OPS_ALL_H_
|
||||
#define MINDSPORE_CORE_OPS_ALL_H_
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAll = "All";
|
||||
/// \brief All defined All operator prototype of lite.
|
||||
class MS_CORE_API All : public PrimitiveC {
|
||||
class MIND_API All : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(All);
|
||||
/// \brief Constructor.
|
||||
All() : PrimitiveC(kNameAll) {}
|
||||
|
||||
/// \brief Destructor.
|
||||
~All() = default;
|
||||
|
||||
MS_DECLARE_PARENT(All, PrimitiveC);
|
||||
All() : BaseOperator(kNameAll) {}
|
||||
|
||||
/// \brief Method to init the op's attributes.
|
||||
///
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
#include "ops/all_gather.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(AllGather, PrimitiveC, BaseOperator);
|
||||
void AllGather::set_group(const string &group) {
|
||||
std::string g = group;
|
||||
(void)this->AddAttr(kGroup, MakeValue(g));
|
||||
(void)this->AddAttr(kGroup, api::MakeValue(g));
|
||||
}
|
||||
std::string AllGather::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
|
@ -30,7 +32,7 @@ std::string AllGather::get_group() const {
|
|||
}
|
||||
|
||||
void AllGather::set_rank_size(int rank_size) {
|
||||
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
|
||||
(void)this->AddAttr(kRankSize, api::MakeValue(static_cast<int64_t>(rank_size)));
|
||||
}
|
||||
int AllGather::get_rank_size() const {
|
||||
auto value_ptr = GetAttr(kRankSize);
|
||||
|
|
|
@ -20,18 +20,16 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAllGather = "AllGather";
|
||||
class MS_CORE_API AllGather : public PrimitiveC {
|
||||
class MIND_API AllGather : public BaseOperator {
|
||||
public:
|
||||
AllGather() : PrimitiveC(kNameAllGather) { InitIOName({"input_x"}, {"output"}); }
|
||||
~AllGather() = default;
|
||||
MS_DECLARE_PARENT(AllGather, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(AllGather);
|
||||
AllGather() : BaseOperator(kNameAllGather) { InitIOName({"input_x"}, {"output"}); }
|
||||
void Init() {}
|
||||
void set_group(const std::string &format);
|
||||
std::string get_group() const;
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -151,6 +153,7 @@ TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<Abstra
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAdaMax, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -20,23 +20,21 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAdaMax = "ApplyAdaMax";
|
||||
class ApplyAdaMax : public PrimitiveC {
|
||||
class MIND_API ApplyAdaMax : public BaseOperator {
|
||||
public:
|
||||
ApplyAdaMax() : PrimitiveC(kNameApplyAdaMax) {
|
||||
MIND_API_BASE_MEMBER(ApplyAdaMax);
|
||||
ApplyAdaMax() : BaseOperator(kNameApplyAdaMax) {
|
||||
InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
|
||||
}
|
||||
~ApplyAdaMax() = default;
|
||||
MS_DECLARE_PARENT(ApplyAdaMax, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -121,6 +122,8 @@ TuplePtr ApplyAdadeltaInferType(const PrimitivePtr &primitive, const std::vector
|
|||
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type, accum_update_type});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAdadelta, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = ApplyAdadeltaInferType(primitive, input_args);
|
||||
|
|
|
@ -22,23 +22,21 @@
|
|||
#include <set>
|
||||
#include <map>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAdadelta = "ApplyAdadelta";
|
||||
class ApplyAdadelta : public PrimitiveC {
|
||||
class MIND_API ApplyAdadelta : public BaseOperator {
|
||||
public:
|
||||
ApplyAdadelta() : PrimitiveC(kNameApplyAdadelta) {
|
||||
MIND_API_BASE_MEMBER(ApplyAdadelta);
|
||||
ApplyAdadelta() : BaseOperator(kNameApplyAdadelta) {
|
||||
InitIOName({"var", "accum", "accum_update", "lr", "rho", "epsilon", "grad"}, {"var", "accum", "accum_update"});
|
||||
}
|
||||
~ApplyAdadelta() = default;
|
||||
MS_DECLARE_PARENT(ApplyAdadelta, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyAdadeltaPtr = std::shared_ptr<ApplyAdadelta>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -81,6 +83,7 @@ TuplePtr ApplyAdagradInferType(const PrimitivePtr &primitive, const std::vector<
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAdagrad, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,22 +22,20 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAdagrad = "ApplyAdagrad";
|
||||
class ApplyAdagrad : public PrimitiveC {
|
||||
class MIND_API ApplyAdagrad : public BaseOperator {
|
||||
public:
|
||||
ApplyAdagrad() : PrimitiveC(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||
~ApplyAdagrad() = default;
|
||||
MS_DECLARE_PARENT(ApplyAdagrad, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(ApplyAdagrad);
|
||||
ApplyAdagrad() : BaseOperator(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using kPrimApplyAdagradPtr = std::shared_ptr<ApplyAdagrad>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -23,10 +23,10 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
namespace {
|
||||
abstract::TupleShapePtr ApplyAdagradDAInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
@ -98,6 +98,7 @@ TuplePtr ApplyAdagradDAInferType(const PrimitivePtr &prim, const std::vector<Abs
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAdagradDA, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,31 +22,26 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAdagradDA = "ApplyAdagradDA";
|
||||
/// \brief Update var according to the proximal adagrad scheme.
|
||||
/// Refer to Python API @ref mindspore.ops.ApplyAdagradDA for more details.
|
||||
class ApplyAdagradDA : public PrimitiveC {
|
||||
class MIND_API ApplyAdagradDA : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ApplyAdagradDA);
|
||||
/// \brief Constructor.
|
||||
ApplyAdagradDA() : PrimitiveC(kNameApplyAdagradDA) {
|
||||
ApplyAdagradDA() : BaseOperator(kNameApplyAdagradDA) {
|
||||
InitIOName({"var", "gradient_accumulator", "gradient_squared_accumulator", "grad", "lr", "l1", "l2", "global_step"},
|
||||
{"var", "gradient_accumulator", "gradient_squared_accumulator"});
|
||||
}
|
||||
|
||||
/// \brief Destructor.
|
||||
~ApplyAdagradDA() = default;
|
||||
|
||||
MS_DECLARE_PARENT(ApplyAdagradDA, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -77,6 +78,7 @@ TuplePtr ApplyAdagradV2InferType(const PrimitivePtr &prim, const std::vector<Abs
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAdagradV2, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,23 +22,19 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAdagradV2 = "ApplyAdagradV2";
|
||||
class ApplyAdagradV2 : public PrimitiveC {
|
||||
class MIND_API ApplyAdagradV2 : public BaseOperator {
|
||||
public:
|
||||
ApplyAdagradV2() : PrimitiveC(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||
|
||||
~ApplyAdagradV2() = default;
|
||||
|
||||
MS_DECLARE_PARENT(ApplyAdagradV2, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(ApplyAdagradV2);
|
||||
ApplyAdagradV2() : BaseOperator(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||
};
|
||||
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyAdagradV2Ptr = std::shared_ptr<ApplyAdagradV2>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -91,6 +93,7 @@ TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vect
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAdamWithAmsgrad, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,24 +19,22 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAdamWithAmsgrad = "ApplyAdamWithAmsgrad";
|
||||
class ApplyAdamWithAmsgrad : public PrimitiveC {
|
||||
class MIND_API ApplyAdamWithAmsgrad : public BaseOperator {
|
||||
public:
|
||||
ApplyAdamWithAmsgrad() : PrimitiveC(kNameApplyAdamWithAmsgrad) {
|
||||
MIND_API_BASE_MEMBER(ApplyAdamWithAmsgrad);
|
||||
ApplyAdamWithAmsgrad() : BaseOperator(kNameApplyAdamWithAmsgrad) {
|
||||
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
|
||||
}
|
||||
~ApplyAdamWithAmsgrad() = default;
|
||||
MS_DECLARE_PARENT(ApplyAdamWithAmsgrad, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimApplyAdamWithAmsgradPtr = std::shared_ptr<ApplyAdamWithAmsgrad>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -116,6 +117,7 @@ TuplePtr ApplyAddSignInferType(const PrimitivePtr &prim, const std::vector<Abstr
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyAddSign, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -21,27 +21,23 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyAddSign = "ApplyAddSign";
|
||||
|
||||
class ApplyAddSign : public PrimitiveC {
|
||||
class MIND_API ApplyAddSign : public BaseOperator {
|
||||
public:
|
||||
ApplyAddSign() : PrimitiveC(kNameApplyAddSign) {
|
||||
MIND_API_BASE_MEMBER(ApplyAddSign);
|
||||
ApplyAddSign() : BaseOperator(kNameApplyAddSign) {
|
||||
InitIOName({"var", "m", "lr", "alpha", "sign_decay", "beta", "grad"}, {"var", "m"});
|
||||
}
|
||||
|
||||
~ApplyAddSign() = default;
|
||||
|
||||
MS_DECLARE_PARENT(ApplyAddSign, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyAddSignPtr = std::shared_ptr<ApplyAddSign>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "ops/apply_centered_rms_prop.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -106,6 +107,8 @@ TypePtr ApplyCenteredRMSPropInferType(const PrimitivePtr &primitive, const std::
|
|||
return var_dtype;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyCenteredRMSProp, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = ApplyCenteredRMSPropInferType(primitive, input_args);
|
||||
|
|
|
@ -22,25 +22,23 @@
|
|||
#include <set>
|
||||
#include <map>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyCenteredRMSProp = "ApplyCenteredRMSProp";
|
||||
class ApplyCenteredRMSProp : public PrimitiveC {
|
||||
class MIND_API ApplyCenteredRMSProp : public BaseOperator {
|
||||
public:
|
||||
ApplyCenteredRMSProp() : PrimitiveC(kNameApplyCenteredRMSProp) {
|
||||
MIND_API_BASE_MEMBER(ApplyCenteredRMSProp);
|
||||
ApplyCenteredRMSProp() : BaseOperator(kNameApplyCenteredRMSProp) {
|
||||
InitIOName(
|
||||
{"var", "mean_gradient", "mean_square", "moment", "grad", "learning_rate", "decay", "momentum", "epsilon"},
|
||||
{"var", "mean_gradient", "mean_square", "moment"});
|
||||
}
|
||||
~ApplyCenteredRMSProp() = default;
|
||||
MS_DECLARE_PARENT(ApplyCenteredRMSProp, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyCenteredRMSPropPtr = std::shared_ptr<ApplyCenteredRMSProp>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -91,6 +92,8 @@ TypePtr ApplyFtrlInferType(const PrimitivePtr &prim, const std::vector<AbstractB
|
|||
return var_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyFtrl, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,24 +22,21 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyFtrl = "ApplyFtrl";
|
||||
class ApplyFtrl : public PrimitiveC {
|
||||
class MIND_API ApplyFtrl : public BaseOperator {
|
||||
public:
|
||||
ApplyFtrl() : PrimitiveC(kNameApplyFtrl) {
|
||||
MIND_API_BASE_MEMBER(ApplyFtrl);
|
||||
ApplyFtrl() : BaseOperator(kNameApplyFtrl) {
|
||||
InitIOName({"var", "accum", "linear", "grad", "lr", "l1", "l2", "lr_power"}, {"var"});
|
||||
}
|
||||
|
||||
~ApplyFtrl() = default;
|
||||
MS_DECLARE_PARENT(ApplyFtrl, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyFtrlPtr = std::shared_ptr<ApplyFtrl>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -72,6 +73,7 @@ TypePtr ApplyGradientDescentInferType(const PrimitivePtr &prim, const std::vecto
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyGradientDescent, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,24 +22,20 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyGradientDescent = "ApplyGradientDescent";
|
||||
class ApplyGradientDescent : public PrimitiveC {
|
||||
class MIND_API ApplyGradientDescent : public BaseOperator {
|
||||
public:
|
||||
ApplyGradientDescent() : PrimitiveC(kNameApplyGradientDescent) { InitIOName({"var", "alpha", "delta"}, {"var"}); }
|
||||
|
||||
~ApplyGradientDescent() = default;
|
||||
|
||||
MS_DECLARE_PARENT(ApplyGradientDescent, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(ApplyGradientDescent);
|
||||
ApplyGradientDescent() : BaseOperator(kNameApplyGradientDescent) { InitIOName({"var", "alpha", "delta"}, {"var"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimApplyGradientDescentPtr = std::shared_ptr<ApplyGradientDescent>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -81,6 +83,7 @@ TuplePtr ApplyKerasMomentumInferType(const PrimitivePtr &prim, const std::vector
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyKerasMomentum, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,24 +22,22 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyKerasMomentum = "ApplyKerasMomentum";
|
||||
class MS_CORE_API ApplyKerasMomentum : public PrimitiveC {
|
||||
class MIND_API ApplyKerasMomentum : public BaseOperator {
|
||||
public:
|
||||
ApplyKerasMomentum() : PrimitiveC(kNameApplyKerasMomentum) {
|
||||
MIND_API_BASE_MEMBER(ApplyKerasMomentum);
|
||||
ApplyKerasMomentum() : BaseOperator(kNameApplyKerasMomentum) {
|
||||
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
|
||||
}
|
||||
~ApplyKerasMomentum() = default;
|
||||
MS_DECLARE_PARENT(ApplyKerasMomentum, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimApplyKerasMomentumPtr = std::shared_ptr<ApplyKerasMomentum>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -30,15 +31,15 @@ void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const
|
|||
}
|
||||
|
||||
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
|
||||
(void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov));
|
||||
(void)this->AddAttr(kUseNesterov, api::MakeValue(use_nesterov));
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_use_locking(const bool use_locking) {
|
||||
(void)this->AddAttr(kUseLocking, MakeValue(use_locking));
|
||||
(void)this->AddAttr(kUseLocking, api::MakeValue(use_locking));
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
|
||||
(void)this->AddAttr(kGradientScale, MakeValue(gradient_scale));
|
||||
(void)this->AddAttr(kGradientScale, api::MakeValue(gradient_scale));
|
||||
}
|
||||
|
||||
bool ApplyMomentum::get_use_nesterov() const {
|
||||
|
@ -102,6 +103,8 @@ TypePtr ApplyMomentumInferType(const PrimitivePtr &primitive, const std::vector<
|
|||
return v_tensor_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyMomentum, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = ApplyMomentumInferType(primitive, input_args);
|
||||
|
|
|
@ -22,24 +22,21 @@
|
|||
#include <set>
|
||||
#include <map>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyMomentum = "ApplyMomentum";
|
||||
/// \brief Optimizer that implements the Momentum algorithm.
|
||||
/// Refer to Python API @ref mindspore.ops.ApplyMomentum for more details.
|
||||
class MS_CORE_API ApplyMomentum : public PrimitiveC {
|
||||
class MIND_API ApplyMomentum : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ApplyMomentum);
|
||||
/// \brief Constructor.
|
||||
ApplyMomentum() : PrimitiveC(kNameApplyMomentum) {
|
||||
ApplyMomentum() : BaseOperator(kNameApplyMomentum) {
|
||||
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
|
||||
}
|
||||
/// \brief Destructor.
|
||||
~ApplyMomentum() = default;
|
||||
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ApplyMomentum for the inputs.
|
||||
void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0);
|
||||
/// \brief Set use_nesterov.
|
||||
|
@ -61,8 +58,8 @@ class MS_CORE_API ApplyMomentum : public PrimitiveC {
|
|||
/// \return gradient_scale.
|
||||
float get_gradient_scale() const;
|
||||
};
|
||||
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -108,6 +109,7 @@ TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<Ab
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyPowerSign, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,23 +19,21 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
|
||||
class ApplyPowerSign : public PrimitiveC {
|
||||
class MIND_API ApplyPowerSign : public BaseOperator {
|
||||
public:
|
||||
ApplyPowerSign() : PrimitiveC(kNameApplyPowerSign) {
|
||||
MIND_API_BASE_MEMBER(ApplyPowerSign);
|
||||
ApplyPowerSign() : BaseOperator(kNameApplyPowerSign) {
|
||||
InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
|
||||
}
|
||||
~ApplyPowerSign() = default;
|
||||
MS_DECLARE_PARENT(ApplyPowerSign, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -99,6 +101,7 @@ TuplePtr ApplyProximalAdagradInferType(const PrimitivePtr &primitive, const std:
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyProximalAdagrad, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,24 +22,22 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyProximalAdagrad = "ApplyProximalAdagrad";
|
||||
class ApplyProximalAdagrad : public PrimitiveC {
|
||||
class MIND_API ApplyProximalAdagrad : public BaseOperator {
|
||||
public:
|
||||
ApplyProximalAdagrad() : PrimitiveC(kNameApplyProximalAdagrad) {
|
||||
MIND_API_BASE_MEMBER(ApplyProximalAdagrad);
|
||||
ApplyProximalAdagrad() : BaseOperator(kNameApplyProximalAdagrad) {
|
||||
InitIOName({"var", "accum", "lr", "l1", "l2", "grad"}, {"var", "accum"});
|
||||
}
|
||||
~ApplyProximalAdagrad() = default;
|
||||
MS_DECLARE_PARENT(ApplyProximalAdagrad, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using kPrimApplyProximalAdagradPtr = std::shared_ptr<ApplyProximalAdagrad>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -98,6 +99,7 @@ TypePtr ApplyProximalGradientDescentInferType(const PrimitivePtr &prim,
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApplyProximalGradientDescent, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
const int64_t input_num = 5;
|
||||
|
|
|
@ -19,23 +19,22 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyProximalGradientDescent = "ApplyProximalGradientDescent";
|
||||
class ApplyProximalGradientDescent : public PrimitiveC {
|
||||
class MIND_API ApplyProximalGradientDescent : public BaseOperator {
|
||||
public:
|
||||
ApplyProximalGradientDescent() : PrimitiveC(kNameApplyProximalGradientDescent) {
|
||||
MIND_API_BASE_MEMBER(ApplyProximalGradientDescent);
|
||||
ApplyProximalGradientDescent() : BaseOperator(kNameApplyProximalGradientDescent) {
|
||||
InitIOName({"var", "alpha", "l1", "l2", "delta"}, {"var"});
|
||||
}
|
||||
~ApplyProximalGradientDescent() = default;
|
||||
MS_DECLARE_PARENT(ApplyProximalGradientDescent, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -18,6 +18,9 @@
|
|||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -59,6 +62,8 @@ TypePtr ApproximateEqualInferType(const PrimitivePtr &prim, const std::vector<Ab
|
|||
return y_dtype;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(ApproximateEqual, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,21 +19,18 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class ApproximateEqual : public PrimitiveC {
|
||||
class MIND_API ApproximateEqual : public BaseOperator {
|
||||
public:
|
||||
ApproximateEqual() : PrimitiveC(prim::kPrimApproximateEqual->name()) {}
|
||||
~ApproximateEqual() = default;
|
||||
MS_DECLARE_PARENT(ApproximateEqual, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(ApproximateEqual);
|
||||
ApproximateEqual() : BaseOperator("ApproximateEqual") {}
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimApproximateEqualPtr = std::shared_ptr<ApproximateEqual>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,10 @@
|
|||
*/
|
||||
|
||||
#include "ops/arg_max.h"
|
||||
#include "mindapi/ir/type.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -23,15 +27,17 @@ void ArgMax::Init(const int64_t axis, const TypeId output_type) {
|
|||
set_output_type(output_type);
|
||||
}
|
||||
|
||||
void ArgMax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
void ArgMax::set_output_type(const TypeId output_type) { (void)this->AddAttr(kOutputType, TypeIdToType(output_type)); }
|
||||
void ArgMax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
void ArgMax::set_output_type(const TypeId output_type) {
|
||||
(void)this->AddAttr(kOutputType, api::Type::GetType(output_type));
|
||||
}
|
||||
|
||||
int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
TypeId ArgMax::get_output_type() const {
|
||||
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
|
||||
auto type_ptr = GetAttr(kOutputType)->cast<api::TensorTypePtr>()->element();
|
||||
return type_ptr->type_id();
|
||||
}
|
||||
|
||||
MIND_API_BASE_IMPL(ArgMax, PrimitiveC, BaseOperator);
|
||||
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,24 +20,21 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameArgMax = "Argmax";
|
||||
/// \brief Returns the indices of the maximum value of a tensor across the axis.
|
||||
/// Refer to Python API @ref mindspore.ops.Argmax for more details.
|
||||
class MS_CORE_API ArgMax : public PrimitiveC {
|
||||
class MIND_API ArgMax : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ArgMax);
|
||||
/// \brief Constructor.
|
||||
ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); }
|
||||
explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~ArgMax() = default;
|
||||
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
|
||||
ArgMax() : BaseOperator(kNameArgMax) { InitIOName({"x"}, {"output"}); }
|
||||
explicit ArgMax(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmax for the inputs.
|
||||
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
||||
/// \brief Set axis.
|
||||
|
@ -54,8 +51,8 @@ class MS_CORE_API ArgMax : public PrimitiveC {
|
|||
/// \return output_type.
|
||||
TypeId get_output_type() const;
|
||||
};
|
||||
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -16,21 +16,28 @@
|
|||
|
||||
#include <set>
|
||||
#include "ops/arg_min.h"
|
||||
#include "mindapi/ir/type.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(ArgMin, PrimitiveC, BaseOperator);
|
||||
void ArgMin::Init(const int64_t axis, const TypeId output_type) {
|
||||
set_axis(axis);
|
||||
set_output_type(output_type);
|
||||
}
|
||||
|
||||
void ArgMin::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
void ArgMin::set_output_type(const TypeId output_type) { (void)this->AddAttr(kOutputType, TypeIdToType(output_type)); }
|
||||
void ArgMin::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
void ArgMin::set_output_type(const TypeId output_type) {
|
||||
(void)this->AddAttr(kOutputType, api::Type::GetType(output_type));
|
||||
}
|
||||
|
||||
int64_t ArgMin::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
|
||||
TypeId ArgMin::get_output_type() const {
|
||||
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
|
||||
auto type_ptr = GetAttr(kOutputType)->cast<api::TensorTypePtr>()->element();
|
||||
return type_ptr->type_id();
|
||||
}
|
||||
|
||||
|
|
|
@ -20,24 +20,21 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameArgMin = "ArgMin";
|
||||
/// \brief Returns the indices of the minimum value of a tensor across the axis.
|
||||
/// Refer to Python API @ref mindspore.ops.Argmin for more details.
|
||||
class MS_CORE_API ArgMin : public PrimitiveC {
|
||||
class MIND_API ArgMin : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ArgMin);
|
||||
/// \brief Constructor.
|
||||
ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); }
|
||||
explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~ArgMin() = default;
|
||||
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
|
||||
ArgMin() : BaseOperator(kNameArgMin) { InitIOName({"x"}, {"output"}); }
|
||||
explicit ArgMin(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmin for the inputs.
|
||||
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
||||
/// \brief Set axis.
|
||||
|
@ -54,8 +51,8 @@ class MS_CORE_API ArgMin : public PrimitiveC {
|
|||
/// \return output_type.
|
||||
TypeId get_output_type() const;
|
||||
};
|
||||
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimArgMin = std::shared_ptr<ArgMin>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -42,6 +43,7 @@ TypePtr AsinInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Asin, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,30 +22,26 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAsin = "Asin";
|
||||
/// \brief Computes arcsine of input tensors element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.Asin for more details.
|
||||
class MS_CORE_API Asin : public PrimitiveC {
|
||||
class MIND_API Asin : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Asin);
|
||||
/// \brief Constructor.
|
||||
Asin() : PrimitiveC(kNameAsin) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Destructor.
|
||||
~Asin() = default;
|
||||
|
||||
MS_DECLARE_PARENT(Asin, PrimitiveC);
|
||||
Asin() : BaseOperator(kNameAsin) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Asin for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimAsinPtr = std::shared_ptr<Asin>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -15,6 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "ops/asinh.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -42,6 +47,7 @@ TypePtr AsinhInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Asinh, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -22,29 +22,24 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAsinh = "Asinh";
|
||||
/// \brief Computes arcsinh of input tensors element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.Asinh for more details.
|
||||
class MS_CORE_API Asinh : public PrimitiveC {
|
||||
class MIND_API Asinh : public BaseOperator {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
Asinh() : PrimitiveC(kNameAsinh) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Destructor.
|
||||
~Asinh() = default;
|
||||
|
||||
MS_DECLARE_PARENT(Asinh, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(Asinh);
|
||||
Asinh() : BaseOperator(kNameAsinh) { InitIOName({"x"}, {"y"}); }
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimAsinhPtr = std::shared_ptr<Asinh>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -21,13 +21,16 @@
|
|||
#include <memory>
|
||||
|
||||
#include "ops/assert.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(Assert, PrimitiveC, BaseOperator);
|
||||
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
|
||||
|
||||
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, MakeValue(summarize)); }
|
||||
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, api::MakeValue(summarize)); }
|
||||
|
||||
int64_t Assert::get_summarize() const {
|
||||
auto value_ptr = GetAttr(kSummarize);
|
||||
|
|
|
@ -18,23 +18,18 @@
|
|||
#define MINDSPORE_CORE_OPS_ASSERT_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssert = "Assert";
|
||||
/// \brief Assert defined Assert operator prototype of lite.
|
||||
class MS_CORE_API Assert : public PrimitiveC {
|
||||
class MIND_API Assert : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Assert);
|
||||
/// \brief Constructor.
|
||||
Assert() : PrimitiveC(kNameAssert) {}
|
||||
|
||||
/// \brief Destructor.
|
||||
~Assert() = default;
|
||||
|
||||
MS_DECLARE_PARENT(Assert, PrimitiveC);
|
||||
Assert() : BaseOperator(kNameAssert) {}
|
||||
|
||||
/// \brief Method to init the op's attributes.
|
||||
///
|
||||
|
@ -52,8 +47,8 @@ class MS_CORE_API Assert : public PrimitiveC {
|
|||
int64_t get_summarize() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -23,9 +23,12 @@
|
|||
#include "ops/assign.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ir/dtype/ref.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(Assign, PrimitiveC, BaseOperator);
|
||||
abstract::ShapePtr AssignInferShape(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
|
|
|
@ -19,21 +19,18 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssign = "Assign";
|
||||
/// \brief Assigns Parameter with a value. Refer to Python API @ref mindspore.ops.Assign for more details.
|
||||
class MS_CORE_API Assign : public PrimitiveC {
|
||||
class MIND_API Assign : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Assign);
|
||||
/// \brief Constructor.
|
||||
Assign() : PrimitiveC(kNameAssign) { InitIOName({"ref", "value"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~Assign() = default;
|
||||
MS_DECLARE_PARENT(Assign, PrimitiveC);
|
||||
Assign() : BaseOperator(kNameAssign) { InitIOName({"ref", "value"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Assign for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "ops/assign_add.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -39,6 +40,8 @@ TypePtr AssignAddInferType(const PrimitivePtr &primitive, const std::vector<Abst
|
|||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(AssignAdd, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,27 +19,24 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssignAdd = "AssignAdd";
|
||||
/// \brief Updates a Parameter by adding a value to it.
|
||||
/// Refer to Python API @ref mindspore.ops.AssignAdd for more details.
|
||||
class MS_CORE_API AssignAdd : public PrimitiveC {
|
||||
class MIND_API AssignAdd : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AssignAdd);
|
||||
/// \brief Constructor.
|
||||
AssignAdd() : PrimitiveC(kNameAssignAdd) { InitIOName({"ref", "value"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~AssignAdd() = default;
|
||||
MS_DECLARE_PARENT(AssignAdd, PrimitiveC);
|
||||
AssignAdd() : BaseOperator(kNameAssignAdd) { InitIOName({"ref", "value"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AssignAdd for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimAssignAddPtr = std::shared_ptr<AssignAdd>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <string>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -41,6 +42,7 @@ TypePtr AssignSubInferType(const PrimitivePtr &primitive, const std::vector<Abst
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(AssignSub, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -19,22 +19,20 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssignSub = "AssignSub";
|
||||
class AssignSub : public PrimitiveC {
|
||||
class MIND_API AssignSub : public BaseOperator {
|
||||
public:
|
||||
AssignSub() : PrimitiveC(kNameAssignSub) { InitIOName({"val", "value"}, {"val"}); }
|
||||
~AssignSub() = default;
|
||||
MS_DECLARE_PARENT(AssignSub, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(AssignSub);
|
||||
AssignSub() : BaseOperator(kNameAssignSub) { InitIOName({"val", "value"}, {"val"}); }
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimAssignSubPtr = std::shared_ptr<AssignSub>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -52,6 +53,8 @@ TypePtr AtanInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
|||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Atan, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -20,27 +20,24 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAtan = "Atan";
|
||||
/// \brief Computes the trigonometric inverse tangent of the input element-wise.
|
||||
/// Refer to Python API @ref mindspore.ops.Atan for more details.
|
||||
class MS_CORE_API Atan : public PrimitiveC {
|
||||
class MIND_API Atan : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Atan);
|
||||
/// \brief Constructor.
|
||||
Atan() : PrimitiveC(kNameAtan) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~Atan() = default;
|
||||
MS_DECLARE_PARENT(Atan, PrimitiveC);
|
||||
Atan() : BaseOperator(kNameAtan) {}
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Atan for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -58,6 +59,8 @@ TypePtr AtanhInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
|||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Atanh, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto type = AtanhInferType(primitive, input_args);
|
||||
|
|
|
@ -20,22 +20,20 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAtanh = "Atanh";
|
||||
class Atanh : public PrimitiveC {
|
||||
class MIND_API Atanh : public BaseOperator {
|
||||
public:
|
||||
Atanh() : PrimitiveC(kNameAtanh) { InitIOName({"x"}, {"output"}); }
|
||||
~Atanh() = default;
|
||||
MS_DECLARE_PARENT(Atanh, PrimitiveC);
|
||||
MIND_API_BASE_MEMBER(Atanh);
|
||||
Atanh() : BaseOperator(kNameAtanh) { InitIOName({"x"}, {"output"}); }
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimAtanhPtr = std::shared_ptr<Atanh>;
|
||||
} // namespace ops
|
||||
|
|
|
@ -16,7 +16,10 @@
|
|||
*/
|
||||
|
||||
#include "ops/attention.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
MIND_API_BASE_IMPL(Attention, PrimitiveC, BaseOperator);
|
||||
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
|
||||
} // namespace mindspore::ops
|
||||
|
|
|
@ -19,25 +19,23 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAttention = "Attention";
|
||||
/// \brief MultiHead-Attention op in MindIR.
|
||||
class MS_CORE_API Attention : public PrimitiveC {
|
||||
class MIND_API Attention : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Attention);
|
||||
/// \brief Constructor.
|
||||
Attention() : PrimitiveC(kNameAttention) {
|
||||
Attention() : BaseOperator(kNameAttention) {
|
||||
InitIOName(
|
||||
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"},
|
||||
{"output"});
|
||||
}
|
||||
/// \brief Destructor.
|
||||
~Attention() override = default;
|
||||
MS_DECLARE_PARENT(Attention, PrimitiveC);
|
||||
/// \brief Initialize Attention op.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -23,24 +23,28 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(AudioSpectrogram, PrimitiveC, BaseOperator);
|
||||
void AudioSpectrogram::set_window_size(const int64_t window_size) {
|
||||
(void)this->AddAttr(kWindowSize, MakeValue(window_size));
|
||||
(void)this->AddAttr(kWindowSize, api::MakeValue(window_size));
|
||||
}
|
||||
int64_t AudioSpectrogram::get_window_size() const {
|
||||
auto value_ptr = GetAttr(kWindowSize);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void AudioSpectrogram::set_stride(const int64_t stride) { (void)this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void AudioSpectrogram::set_stride(const int64_t stride) { (void)this->AddAttr(kStride, api::MakeValue(stride)); }
|
||||
int64_t AudioSpectrogram::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void AudioSpectrogram::set_mag_square(const bool mag_square) { (void)this->AddAttr(kMagSquare, MakeValue(mag_square)); }
|
||||
void AudioSpectrogram::set_mag_square(const bool mag_square) {
|
||||
(void)this->AddAttr(kMagSquare, api::MakeValue(mag_square));
|
||||
}
|
||||
bool AudioSpectrogram::get_mag_square() const {
|
||||
auto value_ptr = GetAttr(kMagSquare);
|
||||
return GetValue<bool>(value_ptr);
|
||||
|
|
|
@ -20,23 +20,18 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
|
||||
/// \brief AudioSpectrogram defined AudioSpectrogram operator prototype.
|
||||
class MS_CORE_API AudioSpectrogram : public PrimitiveC {
|
||||
class MIND_API AudioSpectrogram : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AudioSpectrogram);
|
||||
/// \brief Constructor.
|
||||
AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {}
|
||||
|
||||
/// \brief Destructor.
|
||||
~AudioSpectrogram() = default;
|
||||
|
||||
MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
|
||||
AudioSpectrogram() : BaseOperator(kNameAudioSpectrogram) {}
|
||||
|
||||
/// \brief Method to init the op's attributes.
|
||||
///
|
||||
|
@ -75,8 +70,8 @@ class MS_CORE_API AudioSpectrogram : public PrimitiveC {
|
|||
/// \return a boolean value.
|
||||
bool get_mag_square() const;
|
||||
};
|
||||
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -23,35 +23,37 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, api::MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
|
||||
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
(void)this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
(void)this->AddAttr(
|
||||
kKernelSize, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
|
||||
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
|
||||
(void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
(void)this->AddAttr(kStrides,
|
||||
api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
|
||||
|
||||
void AvgPool::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, api::MakeValue(f));
|
||||
}
|
||||
|
||||
Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
||||
|
||||
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, api::MakeValue(pad)); }
|
||||
|
||||
std::vector<int64_t> AvgPool::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
|
@ -60,7 +62,7 @@ std::vector<int64_t> AvgPool::get_pad() const {
|
|||
|
||||
void AvgPool::set_round_mode(const RoundMode &round_mode) {
|
||||
int64_t swi = round_mode;
|
||||
(void)this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kRoundMode, api::MakeValue(swi));
|
||||
}
|
||||
|
||||
RoundMode AvgPool::get_round_mode() const {
|
||||
|
@ -78,6 +80,7 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
|
|||
this->set_round_mode(round_mode);
|
||||
}
|
||||
|
||||
MIND_API_BASE_IMPL(AvgPool, PrimitiveC, BaseOperator);
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,22 +21,20 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/base/format.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAvgPool = "AvgPool";
|
||||
/// \brief Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool for more details.
|
||||
class MS_CORE_API AvgPool : public PrimitiveC {
|
||||
class MIND_API AvgPool : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AvgPool);
|
||||
/// \brief Constructor.
|
||||
AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
|
||||
explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~AvgPool() = default;
|
||||
MS_DECLARE_PARENT(AvgPool, PrimitiveC);
|
||||
AvgPool() : BaseOperator(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
|
||||
explicit AvgPool(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AvgPool for the inputs.
|
||||
void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1},
|
||||
const PadMode &pad_mode = VALID, const Format &format = NCHW,
|
||||
|
@ -80,8 +78,8 @@ class MS_CORE_API AvgPool : public PrimitiveC {
|
|||
RoundMode get_round_mode() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -180,6 +181,7 @@ TypePtr AvgPool3DInferType(const PrimitivePtr &primitive, const std::vector<Abst
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(AvgPool3D, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(AvgPool3DInferShape(primitive, input_args), AvgPool3DInferType(primitive, input_args));
|
||||
|
|
|
@ -21,24 +21,21 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief 3D Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool3D for more details.
|
||||
class MS_CORE_API AvgPool3D : public PrimitiveC {
|
||||
class MIND_API AvgPool3D : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AvgPool3D);
|
||||
/// \brief Constructor.
|
||||
AvgPool3D() : PrimitiveC(prim::kPrimAvgPool3D->name()) { InitIOName({"input"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~AvgPool3D() = default;
|
||||
MS_DECLARE_PARENT(AvgPool3D, PrimitiveC);
|
||||
AvgPool3D() : BaseOperator("AvgPool3D") { InitIOName({"input"}, {"output"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -15,13 +15,18 @@
|
|||
*/
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(BaseOperator, PrimitiveC, api::Primitive);
|
||||
BaseOperator::BaseOperator(const std::string &name) : api::Primitive(std::make_shared<PrimitiveC>(name)) {}
|
||||
|
||||
PrimitiveCPtr BaseOperator::GetPrim() {
|
||||
PrimitiveCPtr res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
|
||||
return res;
|
||||
}
|
||||
void BaseOperator::InitIOName(const std::vector<std::string> &inputs_name,
|
||||
const std::vector<std::string> &outputs_name) {
|
||||
(void)AddAttr("input_names", api::MakeValue(inputs_name));
|
||||
|
|
|
@ -17,26 +17,36 @@
|
|||
#ifndef MINDSPORE_CORE_OPS_BASE_OPERATOR_
|
||||
#define MINDSPORE_CORE_OPS_BASE_OPERATOR_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mindapi/ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
class AnalysisEngine;
|
||||
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;
|
||||
|
||||
class AbstractBase;
|
||||
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
|
||||
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
namespace mindspore {
|
||||
class Primitive;
|
||||
using PrimitivePtr = std::shared_ptr<Primitive>;
|
||||
} // namespace mindspore
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class BaseOperator : public api::Primitive {
|
||||
class PrimitiveC;
|
||||
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
|
||||
class MIND_API BaseOperator : public api::Primitive {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(BaseOperator);
|
||||
explicit BaseOperator(const std::string &name);
|
||||
~BaseOperator() = default;
|
||||
PrimitiveCPtr GetPrim();
|
||||
|
||||
protected:
|
||||
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -134,14 +135,15 @@ TypePtr BatchMatmulInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(BatchMatmul, PrimitiveC, BaseOperator);
|
||||
void BatchMatmul::Init(bool transpose_a, bool transpose_b) {
|
||||
set_transpose_a(transpose_a);
|
||||
set_transpose_b(transpose_b);
|
||||
}
|
||||
|
||||
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, api::MakeValue(transpose_a)); }
|
||||
|
||||
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, api::MakeValue(transpose_b)); }
|
||||
|
||||
bool BatchMatmul::get_transpose_a() const {
|
||||
auto value_ptr = GetAttr(kTransposeA);
|
||||
|
|
|
@ -18,21 +18,18 @@
|
|||
#define MINDSPORE_CORE_OPS_BATCH_MATMUL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Computes matrix multiplication between two tensors by batch.
|
||||
/// Refer to Python API @ref mindspore.ops.BatchMatmul for more details.
|
||||
class MS_CORE_API BatchMatmul : public PrimitiveC {
|
||||
class MIND_API BatchMatmul : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(BatchMatmul);
|
||||
/// \brief Constructor.
|
||||
BatchMatmul() : PrimitiveC(prim::kPrimBatchMatMul->name()) { InitIOName({"x1", "x2"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~BatchMatmul() = default;
|
||||
MS_DECLARE_PARENT(BatchMatmul, PrimitiveC);
|
||||
BatchMatmul() : BaseOperator("BatchMatMul") { InitIOName({"x1", "x2"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.BatchMatmul for the inputs.
|
||||
void Init(bool transpose_a = false, bool transpose_b = false);
|
||||
/// \brief Set transpose_a.
|
||||
|
@ -48,8 +45,8 @@ class MS_CORE_API BatchMatmul : public PrimitiveC {
|
|||
/// \return transpose_b.
|
||||
bool get_transpose_b() const;
|
||||
};
|
||||
AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
abstract::AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue