!31666 [MS][LITE] new core ops api and lite adapter new api

Merge pull request !31666 from luoyuan/core2
i-robot 2022-03-24 08:50:40 +00:00 committed by Gitee
commit 0b6f330d7e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1839 changed files with 12071 additions and 10885 deletions
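The hunks below repeat one migration pattern across the core ops and the lite adapter: operator classes move from PrimitiveC (MS_CORE_API, MS_DECLARE_PARENT, explicit destructor) to the new BaseOperator API (MIND_API, MIND_API_BASE_MEMBER), each .cc adds MIND_API_BASE_IMPL and includes mindapi/src/helper.h, and attribute setters wrap values with api::MakeValue. A minimal before/after sketch of the header pattern, assembled from the Abs hunk further down (illustration only, not an extra change in this PR):

// Old core ops API: operator derived from PrimitiveC.
class MS_CORE_API Abs : public PrimitiveC {
 public:
  Abs() : PrimitiveC(prim::kPrimAbs->name()) { InitIOName({"input_x"}, {"output"}); }
  ~Abs() = default;
  MS_DECLARE_PARENT(Abs, PrimitiveC);
};

// New core ops API: operator derived from BaseOperator. MIND_API_BASE_MEMBER
// supplies the destructor and parent declaration, and the op name is a plain string.
class MIND_API Abs : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Abs);
  Abs() : BaseOperator("Abs") { InitIOName({"input_x"}, {"output"}); }
};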


@ -24,6 +24,7 @@
#include "utils/shape_utils.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace abstract {


@ -16,6 +16,7 @@
* limitations under the License.
*/
#define USE_DEPRECATED_API
#include "abstract/primitive_infer_map.h"
#include <string>
#include <vector>


@ -69,8 +69,8 @@ class RegisterStandardPrimitiveEvalHelper {
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
auto out = std::make_shared<name>(); \
return out; \
name out; \
return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl()); \
} \
ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
} // namespace abstract
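In the hunk above, the factory generated by the registration macro no longer returns the ops:: class itself as a PrimitiveC; it constructs the new-API operator on the stack and hands out the wrapped primitive through impl(). A rough sketch of what the changed macro body expands to for one operator, assuming the name argument is Abs and that ops::Abs is visible as Abs at the registration site (illustration only):

// Factory returning the legacy PrimitiveC that backs a new-API operator.
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimCAbs() {
  Abs out;  // BaseOperator-derived object from the new API
  // impl() exposes the underlying primitive; cast it to ops::PrimitiveC for the old registry.
  return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl());
}
ops::OpPrimCRegisterHelper primc_gen_Abs("Abs", GetDefaultPrimCAbs);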


@ -114,5 +114,10 @@ enum PaddingMode : int64_t {
SYMMETRIC = 2,
MODE_RESERVED = 3,
};
enum PoolMode : int64_t {
MAX_POOLING = 0,
MEAN_POOLING = 1,
};
} // namespace mindspore
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_


@ -153,5 +153,7 @@ class MIND_API AbstractTuple : public AbstractSequence {
/// \param[in] elements A list of abstracts.
explicit AbstractTuple(const AbstractBasePtrList &elements);
};
using AbstractTuplePtr = SharedPtr<AbstractTuple>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_


@ -47,5 +47,9 @@ using FuncGraphPtr = SharedPtr<FuncGraph>;
class FuncGraphManager;
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
class CNode;
using CNodePtr = SharedPtr<CNode>;
using CNodePtrList = std::vector<CNodePtr>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_


@ -21,6 +21,7 @@
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -49,6 +50,7 @@ TypePtr LayerNormBetaGammaBackpropInferType(const PrimitivePtr &prim, const std:
}
} // namespace
MIND_API_BASE_IMPL(LayerNormBetaGammaBackprop, PrimitiveC, BaseOperator);
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -20,23 +20,22 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
class MS_CORE_API LayerNormBetaGammaBackprop : public PrimitiveC {
class MIND_API LayerNormBetaGammaBackprop : public BaseOperator {
public:
LayerNormBetaGammaBackprop() : PrimitiveC(prim::kPrimLayerNormBetaGammaBackprop->name()) {}
~LayerNormBetaGammaBackprop() = default;
MS_DECLARE_PARENT(LayerNormBetaGammaBackprop, PrimitiveC);
MIND_API_BASE_MEMBER(LayerNormBetaGammaBackprop);
LayerNormBetaGammaBackprop() : BaseOperator("LayerNormBetaGammaBackprop") {}
void Init() const {}
};
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore


@ -21,6 +21,7 @@
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -44,6 +45,7 @@ TypePtr LayerNormXBackpropInferType(const PrimitivePtr &prim, const std::vector<
}
} // namespace
MIND_API_BASE_IMPL(LayerNormXBackprop, PrimitiveC, BaseOperator);
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -20,23 +20,21 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
class MS_CORE_API LayerNormXBackprop : public PrimitiveC {
class MIND_API LayerNormXBackprop : public BaseOperator {
public:
LayerNormXBackprop() : PrimitiveC(prim::kPrimLayerNormXBackprop->name()) {}
~LayerNormXBackprop() = default;
MS_DECLARE_PARENT(LayerNormXBackprop, PrimitiveC);
MIND_API_BASE_MEMBER(LayerNormXBackprop);
LayerNormXBackprop() : BaseOperator("LayerNormXBackprop") {}
void Init() const {}
};
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore


@ -22,6 +22,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -142,6 +143,8 @@ ValuePtr AbsInferValue(const PrimitivePtr &prim, const std::vector<AbstractBaseP
return result_tensor;
}
} // namespace
MIND_API_BASE_IMPL(Abs, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer, AbsInferValue, true);
} // namespace ops
} // namespace mindspore


@ -19,21 +19,18 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
/// \brief Returns absolute value of a tensor element-wise.
/// Refer to Python API @ref mindspore.ops.Abs for more details.
class MS_CORE_API Abs : public PrimitiveC {
class MIND_API Abs : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Abs);
/// \brief Constructor.
Abs() : PrimitiveC(prim::kPrimAbs->name()) { InitIOName({"input_x"}, {"output"}); }
/// \brief Destructor.
~Abs() = default;
MS_DECLARE_PARENT(Abs, PrimitiveC);
Abs() : BaseOperator("Abs") { InitIOName({"input_x"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Abs for the inputs.
void Init() const {}
};


@ -22,6 +22,8 @@
#include "ops/accumulate_n_v2.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -83,6 +85,7 @@ TypePtr AccumulateNV2InferType(const PrimitivePtr &prim, const std::vector<Abstr
}
} // namespace
MIND_API_BASE_IMPL(AccumulateNV2, PrimitiveC, BaseOperator);
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -19,21 +19,19 @@
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAccumulateNV2 = "AccumulateNV2";
class MS_CORE_API AccumulateNV2 : public PrimitiveC {
class MIND_API AccumulateNV2 : public BaseOperator {
public:
AccumulateNV2() : PrimitiveC(kNameAccumulateNV2) { InitIOName({"inputs"}, {"sum"}); }
~AccumulateNV2() = default;
MS_DECLARE_PARENT(AccumulateNV2, PrimitiveC);
MIND_API_BASE_MEMBER(AccumulateNV2);
AccumulateNV2() : BaseOperator(kNameAccumulateNV2) { InitIOName({"inputs"}, {"sum"}); }
};
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAccumulateNV2Ptr = std::shared_ptr<AccumulateNV2>;
} // namespace ops
} // namespace mindspore


@ -15,6 +15,15 @@
*/
#include "ops/acos.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -38,6 +47,7 @@ TypePtr ACosInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
}
} // namespace
MIND_API_BASE_IMPL(ACos, PrimitiveC, BaseOperator);
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -22,28 +22,23 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameACos = "ACos";
/// \brief Computes arccosine of input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.ACos for more details.
class ACos : public PrimitiveC {
class MIND_API ACos : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ACos);
/// \brief Constructor.
ACos() : PrimitiveC(kNameACos) { InitIOName({"x"}, {"y"}); }
/// \brief Destructor.
~ACos() = default;
MS_DECLARE_PARENT(ACos, PrimitiveC);
ACos() : BaseOperator(kNameACos) { InitIOName({"x"}, {"y"}); }
};
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimACosPtr = std::shared_ptr<ACos>;
} // namespace ops


@ -15,6 +15,10 @@
*/
#include "ops/acosh.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -42,6 +46,7 @@ TypePtr AcoshInferType(const PrimitivePtr &primitive, const std::vector<Abstract
}
} // namespace
MIND_API_BASE_IMPL(Acosh, PrimitiveC, BaseOperator);
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -22,28 +22,23 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAcosh = "Acosh";
/// \brief Computes arccosh of input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.Acosh for more details.
class Acosh : public PrimitiveC {
class MIND_API Acosh : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Acosh);
/// \brief Constructor.
Acosh() : PrimitiveC(kNameAcosh) { InitIOName({"x"}, {"y"}); }
/// \brief Destructor.
~Acosh() = default;
MS_DECLARE_PARENT(Acosh, PrimitiveC);
Acosh() : BaseOperator(kNameAcosh) { InitIOName({"x"}, {"y"}); }
};
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAcoshPtr = std::shared_ptr<Acosh>;
} // namespace ops


@ -20,6 +20,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -79,14 +80,18 @@ abstract::TupleShapePtr AdamInferShape(const PrimitivePtr &primitive, const std:
std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
}
} // namespace
MIND_API_BASE_IMPL(Adam, PrimitiveC, BaseOperator);
void Adam::Init(const bool use_locking, const bool use_nesterov) {
this->set_use_locking(use_locking);
this->set_use_nesterov(use_nesterov);
}
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, MakeValue(use_locking)); }
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, api::MakeValue(use_locking)); }
void Adam::set_use_nesterov(const bool use_nesterov) { (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
void Adam::set_use_nesterov(const bool use_nesterov) {
(void)this->AddAttr(kUseNesterov, api::MakeValue(use_nesterov));
}
bool Adam::get_use_locking() const {
auto value_ptr = GetAttr(kUseLocking);


@ -20,22 +20,19 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdam = "Adam";
/// \brief Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
/// Refer to Python API @ref mindspore.ops.Adam for more details.
class MS_CORE_API Adam : public PrimitiveC {
class MIND_API Adam : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Adam);
/// \brief Constructor.
Adam() : PrimitiveC(kNameAdam) {}
/// \brief Destructor.
~Adam() = default;
MS_DECLARE_PARENT(Adam, PrimitiveC);
Adam() : BaseOperator(kNameAdam) {}
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Adam for the inputs.
void Init(const bool use_locking = false, const bool use_nesterov = false);
/// \brief Set use_locking.
@ -51,8 +48,8 @@ class MS_CORE_API Adam : public PrimitiveC {
/// \return use_nesterov.
bool get_use_nesterov() const;
};
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimAdamPtr = std::shared_ptr<Adam>;
} // namespace ops
} // namespace mindspore
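Paired with the adam.cc hunk above, this header shows how attributes travel through the new API: Init forwards to the setters, and each setter wraps its value with api::MakeValue before calling AddAttr. A short usage sketch under those assumptions (names come from the hunks; the wrapper function is hypothetical and the snippet is not a complete, buildable program):

#include "ops/adam.h"

// Sketch: configure the new-API Adam operator and read its attributes back.
void ConfigureAdamExample() {
  mindspore::ops::Adam adam;
  adam.Init(/*use_locking=*/true, /*use_nesterov=*/false);  // forwards to set_use_locking / set_use_nesterov
  bool locking = adam.get_use_locking();    // reads the stored kUseLocking attribute
  bool nesterov = adam.get_use_nesterov();  // reads the stored kUseNesterov attribute
  (void)locking;
  (void)nesterov;
}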


@ -21,9 +21,11 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(Add, PrimitiveC, BaseOperator);
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -20,28 +20,25 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdd = prim::kAdd;
constexpr auto kNameAdd = "Add";
/// \brief Adds two input tensors element-wise. Refer to Python API @ref mindspore.ops.Add for more details.
class MS_CORE_API Add : public PrimitiveC {
class MIND_API Add : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Add);
/// \brief Constructor.
Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); }
/// \brief Destructor.
~Add() = default;
MS_DECLARE_PARENT(Add, PrimitiveC);
Add() : BaseOperator(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
explicit Add(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "y"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Add for the inputs.
void Init() const {}
};
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore


@ -22,6 +22,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -73,6 +74,8 @@ TypePtr AddcdivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
return input_data_type;
}
} // namespace
MIND_API_BASE_IMPL(Addcdiv, PrimitiveC, BaseOperator);
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -19,23 +19,20 @@
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAddcdiv = "Addcdiv";
class Addcdiv : public PrimitiveC {
class MIND_API Addcdiv : public BaseOperator {
public:
Addcdiv() : PrimitiveC(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
~Addcdiv() = default;
MS_DECLARE_PARENT(Addcdiv, PrimitiveC);
MIND_API_BASE_MEMBER(Addcdiv);
Addcdiv() : BaseOperator(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
};
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAddcdivPtr = std::shared_ptr<Addcdiv>;
} // namespace ops
} // namespace mindspore


@ -22,6 +22,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -73,6 +74,8 @@ TypePtr AddcmulInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
return input_data_type;
}
} // namespace
MIND_API_BASE_IMPL(Addcmul, PrimitiveC, BaseOperator);
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -19,23 +19,20 @@
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAddcmul = "Addcmul";
class Addcmul : public PrimitiveC {
class MIND_API Addcmul : public BaseOperator {
public:
Addcmul() : PrimitiveC(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
~Addcmul() = default;
MS_DECLARE_PARENT(Addcmul, PrimitiveC);
MIND_API_BASE_MEMBER(Addcmul);
Addcmul() : BaseOperator(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
};
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAddcmulPtr = std::shared_ptr<Addcmul>;
} // namespace ops
} // namespace mindspore


@ -16,9 +16,11 @@
#include "ops/adder.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(Adder, PrimitiveC, BaseOperator);
void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
const std::vector<int64_t> &dilation, const int64_t group, const Format &format) {
@ -33,14 +35,16 @@ void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std:
set_format(format);
}
void Adder::set_in_channel(const int64_t in_channel) { (void)this->AddAttr(kInChannel, MakeValue(in_channel)); }
void Adder::set_in_channel(const int64_t in_channel) { (void)this->AddAttr(kInChannel, api::MakeValue(in_channel)); }
int64_t Adder::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
return GetValue<int64_t>(value_ptr);
}
void Adder::set_out_channel(const int64_t out_channel) { (void)this->AddAttr(kOutChannel, MakeValue(out_channel)); }
void Adder::set_out_channel(const int64_t out_channel) {
(void)this->AddAttr(kOutChannel, api::MakeValue(out_channel));
}
int64_t Adder::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
@ -48,7 +52,7 @@ int64_t Adder::get_out_channel() const {
}
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
(void)this->AddAttr(kKernelSize, MakeValue(kernel_size));
(void)this->AddAttr(kKernelSize, api::MakeValue(kernel_size));
}
std::vector<int64_t> Adder::get_kernel_size() const {
@ -58,7 +62,7 @@ std::vector<int64_t> Adder::get_kernel_size() const {
void Adder::set_pad_mode(const PadMode &pad_mode) {
int64_t swi = pad_mode;
(void)this->AddAttr(kPadMode, MakeValue(swi));
(void)this->AddAttr(kPadMode, api::MakeValue(swi));
}
PadMode Adder::get_pad_mode() const {
@ -66,28 +70,32 @@ PadMode Adder::get_pad_mode() const {
return PadMode(GetValue<int64_t>(value_ptr));
}
void Adder::set_stride(const std::vector<int64_t> &stride) { (void)this->AddAttr(kStride, MakeValue(stride)); }
void Adder::set_stride(const std::vector<int64_t> &stride) { (void)this->AddAttr(kStride, api::MakeValue(stride)); }
std::vector<int64_t> Adder::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { (void)this->AddAttr(kPadList, MakeValue(pad_list)); }
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) {
(void)this->AddAttr(kPadList, api::MakeValue(pad_list));
}
std::vector<int64_t> Adder::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_dilation(const std::vector<int64_t> &dilation) { (void)this->AddAttr(kDilation, MakeValue(dilation)); }
void Adder::set_dilation(const std::vector<int64_t> &dilation) {
(void)this->AddAttr(kDilation, api::MakeValue(dilation));
}
std::vector<int64_t> Adder::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_group(const int64_t group) { (void)this->AddAttr(kGroup, MakeValue(group)); }
void Adder::set_group(const int64_t group) { (void)this->AddAttr(kGroup, api::MakeValue(group)); }
int64_t Adder::get_group() const {
auto value_ptr = GetAttr(kGroup);
@ -96,7 +104,7 @@ int64_t Adder::get_group() const {
void Adder::set_format(const Format &format) {
int64_t swi = format;
(void)this->AddAttr(kFormat, MakeValue(swi));
(void)this->AddAttr(kFormat, api::MakeValue(swi));
}
Format Adder::get_format() const {


@ -21,22 +21,19 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "mindapi/base/format.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdder = "Adder";
/// \brief Defines the Adder operator prototype for lite.
class MS_CORE_API Adder : public PrimitiveC {
class MIND_API Adder : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Adder);
/// \brief Constructor.
explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {}
/// \brief Destructor.
~Adder() = default;
MS_DECLARE_PARENT(Adder, PrimitiveC);
explicit Adder(const std::string &k_name = kNameAdder) : BaseOperator(k_name) {}
/// \brief Method to init the op's attributes.
///


@ -21,6 +21,8 @@
#include <memory>
#include "ops/addn.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -83,6 +85,8 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
return elements[0]->BuildType();
}
} // namespace
MIND_API_BASE_IMPL(AddN, PrimitiveC, BaseOperator);
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -18,27 +18,24 @@
#define MINDSPORE_CORE_OPS_ADDN_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAddN = "AddN";
/// \brief Computes addition of all input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.AddN for more details.
class MS_CORE_API AddN : public PrimitiveC {
class MIND_API AddN : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AddN);
/// \brief Constructor.
AddN() : PrimitiveC(kNameAddN) { InitIOName({"inputs"}, {"sum"}); }
/// \brief Destructor.
~AddN() = default;
MS_DECLARE_PARENT(AddN, PrimitiveC);
AddN() : BaseOperator(kNameAddN) { InitIOName({"inputs"}, {"sum"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AddN for the inputs.
void Init() const {}
};
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore


@ -17,8 +17,11 @@
#include "ops/affine.h"
#include <vector>
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(Affine, PrimitiveC, BaseOperator);
void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a, bool transpose_b) {
this->set_context(contexts);
this->set_output_dim(output_dim);
@ -27,17 +30,17 @@ void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool
}
void Affine::set_context(const std::vector<int64_t> &context) {
(void)this->AddAttr(kAffineContext, MakeValue(context));
(void)this->AddAttr(kAffineContext, api::MakeValue(context));
}
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, MakeValue(output_dim)); }
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, api::MakeValue(output_dim)); }
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, api::MakeValue(transpose_a)); }
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, api::MakeValue(transpose_b)); }
void Affine::set_activation_type(const ActivationType &activation_type) {
(void)this->AddAttr(kActivationType, MakeValue(static_cast<int64_t>(activation_type)));
(void)this->AddAttr(kActivationType, api::MakeValue(static_cast<int64_t>(activation_type)));
}
bool Affine::get_transpose_a() const {


@ -18,25 +18,21 @@
#define MINDSPORE_CORE_OPS_AFFINE_H_
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAffine = "Affine";
constexpr auto kAffineContext = "context";
constexpr auto kAffineOutputDim = "output_dim";
/// \brief Defines the Affine operator prototype for lite.
class MS_CORE_API Affine : public PrimitiveC {
class MIND_API Affine : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Affine);
/// \brief Constructor.
Affine() : PrimitiveC(kNameAffine) { InitIOName({"x1", "x2"}, {"outputs"}); }
/// \brief Destructor.
~Affine() = default;
MS_DECLARE_PARENT(Affine, PrimitiveC);
Affine() : BaseOperator(kNameAffine) { InitIOName({"x1", "x2"}, {"outputs"}); }
/// \brief Method to init the op's attributes.
void Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a = false,
bool transpose_b = false);


@ -17,12 +17,14 @@
#include "ops/all.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(All, PrimitiveC, BaseOperator);
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
int64_t All::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims);


@ -16,23 +16,18 @@
#ifndef MINDSPORE_CORE_OPS_ALL_H_
#define MINDSPORE_CORE_OPS_ALL_H_
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAll = "All";
/// \brief Defines the All operator prototype for lite.
class MS_CORE_API All : public PrimitiveC {
class MIND_API All : public BaseOperator {
public:
MIND_API_BASE_MEMBER(All);
/// \brief Constructor.
All() : PrimitiveC(kNameAll) {}
/// \brief Destructor.
~All() = default;
MS_DECLARE_PARENT(All, PrimitiveC);
All() : BaseOperator(kNameAll) {}
/// \brief Method to init the op's attributes.
///


@ -17,12 +17,14 @@
#include "ops/all_gather.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(AllGather, PrimitiveC, BaseOperator);
void AllGather::set_group(const string &group) {
std::string g = group;
(void)this->AddAttr(kGroup, MakeValue(g));
(void)this->AddAttr(kGroup, api::MakeValue(g));
}
std::string AllGather::get_group() const {
auto value_ptr = GetAttr(kGroup);
@ -30,7 +32,7 @@ std::string AllGather::get_group() const {
}
void AllGather::set_rank_size(int rank_size) {
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
(void)this->AddAttr(kRankSize, api::MakeValue(static_cast<int64_t>(rank_size)));
}
int AllGather::get_rank_size() const {
auto value_ptr = GetAttr(kRankSize);


@ -20,18 +20,16 @@
#include <string>
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAllGather = "AllGather";
class MS_CORE_API AllGather : public PrimitiveC {
class MIND_API AllGather : public BaseOperator {
public:
AllGather() : PrimitiveC(kNameAllGather) { InitIOName({"input_x"}, {"output"}); }
~AllGather() = default;
MS_DECLARE_PARENT(AllGather, PrimitiveC);
MIND_API_BASE_MEMBER(AllGather);
AllGather() : BaseOperator(kNameAllGather) { InitIOName({"input_x"}, {"output"}); }
void Init() {}
void set_group(const std::string &format);
std::string get_group() const;


@ -1,164 +1,167 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_ada_max.h"
#include <algorithm>
#include <set>
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr ApplyAdaMaxInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 9;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto prim_name = primitive->name();
auto var_shape = input_args[kInputIndex0]->BuildShape();
auto m_shape = input_args[kInputIndex1]->BuildShape();
auto v_shape = input_args[kInputIndex2]->BuildShape();
auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>();
auto m_shape_ptr = m_shape->cast<abstract::ShapePtr>();
auto v_shape_ptr = v_shape->cast<abstract::ShapePtr>();
auto beta1_power_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto beta1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
auto beta2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
auto epsilon_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
auto grad_shape = input_args[kInputIndex8]->BuildShape();
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
// beta1_power,lr,beta1,beta2,epsilon must be scalar
const int64_t kInputShape = 1;
(void)CheckAndConvertUtils::CheckInteger("beta1 power's rank", beta1_power_shape.size(), kLessEqual, kInputShape,
prim_name);
if (beta1_power_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta1_power_shape[0]", beta1_power_shape.size(), kEqual, kInputShape,
prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("lr's rank", lr_shape.size(), kLessEqual, kInputShape, prim_name);
if (lr_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape.size(), kEqual, kInputShape, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("beta1's rank", beta1_shape.size(), kLessEqual, kInputShape, prim_name);
if (beta1_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta1_shape[0]", beta1_shape.size(), kEqual, kInputShape, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("beta2's rank", beta2_shape.size(), kLessEqual, kInputShape, prim_name);
if (beta2_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta2_shape[0]", beta2_shape.size(), kEqual, kInputShape, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("epsilon's rank", epsilon_shape.size(), kLessEqual, kInputShape, prim_name);
if (epsilon_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape[0]", epsilon_shape.size(), kEqual, kInputShape, prim_name);
}
// var, m,v and grad must have the same shape
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
same_shape_args_map.insert({"m", m_shape});
same_shape_args_map.insert({"v", v_shape});
same_shape_args_map.insert({"grad", grad_shape});
if (!var_shape_ptr->IsDynamic() && !m_shape_ptr->IsDynamic()) {
if (*m_shape != *var_shape) {
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg m shape " << m_shape->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
if (!v_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
if (*v_shape != *var_shape) {
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg v shape " << v_shape->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
if (!grad_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
if (*grad_shape != *var_shape) {
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg grad shape " << grad_shape->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape});
}
TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t kInputNum = 9;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto var_type = input_args[kInputIndex0]->BuildType();
auto m_type = input_args[kInputIndex1]->BuildType();
auto v_type = input_args[kInputIndex2]->BuildType();
auto beta1_power_type = input_args[kInputIndex3]->BuildType();
auto lr_type = input_args[kInputIndex4]->BuildType();
auto beta1_type = input_args[kInputIndex5]->BuildType();
auto beta2_type = input_args[kInputIndex6]->BuildType();
auto epsilon_type = input_args[kInputIndex7]->BuildType();
auto grad_type = input_args[kInputIndex8]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
// m v grad must have the same type as var
std::map<std::string, TypePtr> args;
(void)args.insert({"var_type", var_type});
(void)args.insert({"m_type", m_type});
(void)args.insert({"v_type", v_type});
(void)args.insert({"grad_type", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
std::map<std::string, TypePtr> args_beta1_power;
std::map<std::string, TypePtr> args_lr;
std::map<std::string, TypePtr> args_beta1;
std::map<std::string, TypePtr> args_beta2;
std::map<std::string, TypePtr> args_epsilon;
(void)args_beta1_power.insert({"beta1_power_type", beta1_power_type});
(void)args_lr.insert({"lr_type", lr_type});
(void)args_beta1.insert({"beta1_type", beta1_type});
(void)args_beta2.insert({"beta2_type", beta2_type});
(void)args_epsilon.insert({"epsilon_type", epsilon_type});
// beta1_power,lr,beta1,beta2,epsilon must be a scalar or zero dimension tensor type
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1_power, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta2, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_epsilon, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
}
} // namespace
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto infer_type = ApplyAdaMaxInferType(primitive, input_args);
auto infer_shape = ApplyAdaMaxInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdaMax, prim::kPrimApplyAdaMax, ApplyAdaMaxInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_ada_max.h"
#include <algorithm>
#include <set>
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr ApplyAdaMaxInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 9;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto prim_name = primitive->name();
auto var_shape = input_args[kInputIndex0]->BuildShape();
auto m_shape = input_args[kInputIndex1]->BuildShape();
auto v_shape = input_args[kInputIndex2]->BuildShape();
auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>();
auto m_shape_ptr = m_shape->cast<abstract::ShapePtr>();
auto v_shape_ptr = v_shape->cast<abstract::ShapePtr>();
auto beta1_power_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto beta1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
auto beta2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
auto epsilon_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
auto grad_shape = input_args[kInputIndex8]->BuildShape();
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
// beta1_power,lr,beta1,beta2,epsilon must be scalar
const int64_t kInputShape = 1;
(void)CheckAndConvertUtils::CheckInteger("beta1 power's rank", beta1_power_shape.size(), kLessEqual, kInputShape,
prim_name);
if (beta1_power_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta1_power_shape[0]", beta1_power_shape.size(), kEqual, kInputShape,
prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("lr's rank", lr_shape.size(), kLessEqual, kInputShape, prim_name);
if (lr_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape.size(), kEqual, kInputShape, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("beta1's rank", beta1_shape.size(), kLessEqual, kInputShape, prim_name);
if (beta1_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta1_shape[0]", beta1_shape.size(), kEqual, kInputShape, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("beta2's rank", beta2_shape.size(), kLessEqual, kInputShape, prim_name);
if (beta2_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta2_shape[0]", beta2_shape.size(), kEqual, kInputShape, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("epsilon's rank", epsilon_shape.size(), kLessEqual, kInputShape, prim_name);
if (epsilon_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape[0]", epsilon_shape.size(), kEqual, kInputShape, prim_name);
}
// var, m,v and grad must have the same shape
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
same_shape_args_map.insert({"m", m_shape});
same_shape_args_map.insert({"v", v_shape});
same_shape_args_map.insert({"grad", grad_shape});
if (!var_shape_ptr->IsDynamic() && !m_shape_ptr->IsDynamic()) {
if (*m_shape != *var_shape) {
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg m shape " << m_shape->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
if (!v_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
if (*v_shape != *var_shape) {
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg v shape " << v_shape->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
if (!grad_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
if (*grad_shape != *var_shape) {
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg grad shape " << grad_shape->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape});
}
TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t kInputNum = 9;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto var_type = input_args[kInputIndex0]->BuildType();
auto m_type = input_args[kInputIndex1]->BuildType();
auto v_type = input_args[kInputIndex2]->BuildType();
auto beta1_power_type = input_args[kInputIndex3]->BuildType();
auto lr_type = input_args[kInputIndex4]->BuildType();
auto beta1_type = input_args[kInputIndex5]->BuildType();
auto beta2_type = input_args[kInputIndex6]->BuildType();
auto epsilon_type = input_args[kInputIndex7]->BuildType();
auto grad_type = input_args[kInputIndex8]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
// m v grad must have the same type as var
std::map<std::string, TypePtr> args;
(void)args.insert({"var_type", var_type});
(void)args.insert({"m_type", m_type});
(void)args.insert({"v_type", v_type});
(void)args.insert({"grad_type", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
std::map<std::string, TypePtr> args_beta1_power;
std::map<std::string, TypePtr> args_lr;
std::map<std::string, TypePtr> args_beta1;
std::map<std::string, TypePtr> args_beta2;
std::map<std::string, TypePtr> args_epsilon;
(void)args_beta1_power.insert({"beta1_power_type", beta1_power_type});
(void)args_lr.insert({"lr_type", lr_type});
(void)args_beta1.insert({"beta1_type", beta1_type});
(void)args_beta2.insert({"beta2_type", beta2_type});
(void)args_epsilon.insert({"epsilon_type", epsilon_type});
// beta1_power,lr,beta1,beta2,epsilon must be a scalar or zero dimension tensor type
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1_power, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta2, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_epsilon, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
}
} // namespace
MIND_API_BASE_IMPL(ApplyAdaMax, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto infer_type = ApplyAdaMaxInferType(primitive, input_args);
auto infer_shape = ApplyAdaMaxInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdaMax, prim::kPrimApplyAdaMax, ApplyAdaMaxInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@ -1,45 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
#define MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdaMax = "ApplyAdaMax";
class ApplyAdaMax : public PrimitiveC {
public:
ApplyAdaMax() : PrimitiveC(kNameApplyAdaMax) {
InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
}
~ApplyAdaMax() = default;
MS_DECLARE_PARENT(ApplyAdaMax, PrimitiveC);
};
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
#define MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdaMax = "ApplyAdaMax";
class MIND_API ApplyAdaMax : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ApplyAdaMax);
ApplyAdaMax() : BaseOperator(kNameApplyAdaMax) {
InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
}
};
abstract::AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_


@ -23,6 +23,7 @@
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -121,6 +122,8 @@ TuplePtr ApplyAdadeltaInferType(const PrimitivePtr &primitive, const std::vector
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type, accum_update_type});
}
} // namespace
MIND_API_BASE_IMPL(ApplyAdadelta, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = ApplyAdadeltaInferType(primitive, input_args);


@ -22,23 +22,21 @@
#include <set>
#include <map>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdadelta = "ApplyAdadelta";
class ApplyAdadelta : public PrimitiveC {
class MIND_API ApplyAdadelta : public BaseOperator {
public:
ApplyAdadelta() : PrimitiveC(kNameApplyAdadelta) {
MIND_API_BASE_MEMBER(ApplyAdadelta);
ApplyAdadelta() : BaseOperator(kNameApplyAdadelta) {
InitIOName({"var", "accum", "accum_update", "lr", "rho", "epsilon", "grad"}, {"var", "accum", "accum_update"});
}
~ApplyAdadelta() = default;
MS_DECLARE_PARENT(ApplyAdadelta, PrimitiveC);
};
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyAdadeltaPtr = std::shared_ptr<ApplyAdadelta>;
} // namespace ops
} // namespace mindspore


@ -22,6 +22,8 @@
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -81,6 +83,7 @@ TuplePtr ApplyAdagradInferType(const PrimitivePtr &primitive, const std::vector<
}
} // namespace
MIND_API_BASE_IMPL(ApplyAdagrad, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -22,22 +22,20 @@
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdagrad = "ApplyAdagrad";
class ApplyAdagrad : public PrimitiveC {
class MIND_API ApplyAdagrad : public BaseOperator {
public:
ApplyAdagrad() : PrimitiveC(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
~ApplyAdagrad() = default;
MS_DECLARE_PARENT(ApplyAdagrad, PrimitiveC);
MIND_API_BASE_MEMBER(ApplyAdagrad);
ApplyAdagrad() : BaseOperator(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
};
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyAdagradPtr = std::shared_ptr<ApplyAdagrad>;
} // namespace ops


@ -23,10 +23,10 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr ApplyAdagradDAInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@ -98,6 +98,7 @@ TuplePtr ApplyAdagradDAInferType(const PrimitivePtr &prim, const std::vector<Abs
}
} // namespace
MIND_API_BASE_IMPL(ApplyAdagradDA, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@ -22,31 +22,26 @@
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdagradDA = "ApplyAdagradDA";
/// \brief Updates var according to the proximal adagrad scheme.
/// Refer to Python API @ref mindspore.ops.ApplyAdagradDA for more details.
class ApplyAdagradDA : public PrimitiveC {
class MIND_API ApplyAdagradDA : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ApplyAdagradDA);
/// \brief Constructor.
ApplyAdagradDA() : PrimitiveC(kNameApplyAdagradDA) {
ApplyAdagradDA() : BaseOperator(kNameApplyAdagradDA) {
InitIOName({"var", "gradient_accumulator", "gradient_squared_accumulator", "grad", "lr", "l1", "l2", "global_step"},
{"var", "gradient_accumulator", "gradient_squared_accumulator"});
}
/// \brief Destructor.
~ApplyAdagradDA() = default;
MS_DECLARE_PARENT(ApplyAdagradDA, PrimitiveC);
};
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -77,6 +78,7 @@ TuplePtr ApplyAdagradV2InferType(const PrimitivePtr &prim, const std::vector<Abs
}
} // namespace
MIND_API_BASE_IMPL(ApplyAdagradV2, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
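
On the implementation side each .cc file gains a single MIND_API_BASE_IMPL line plus the mindapi helper include, while the existing infer functions stay untouched. A condensed sketch of the .cc-side counterpart for the hypothetical FooBar operator sketched earlier (registration usage mirrors what this commit does for the real ops):

// foo_bar.cc -- hypothetical, mirrors the per-file additions in this commit
#include "ops/foo_bar.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(FooBar, PrimitiveC, BaseOperator);  // ties the MindAPI class to its PrimitiveC impl and BaseOperator parent
REGISTER_PRIMITIVE_C(kNameFooBar, FooBar);             // keeps the legacy primitive-C registry entry working
}  // namespace ops
}  // namespace mindspore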

View File

@ -22,23 +22,19 @@
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdagradV2 = "ApplyAdagradV2";
class ApplyAdagradV2 : public PrimitiveC {
class MIND_API ApplyAdagradV2 : public BaseOperator {
public:
ApplyAdagradV2() : PrimitiveC(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
~ApplyAdagradV2() = default;
MS_DECLARE_PARENT(ApplyAdagradV2, PrimitiveC);
MIND_API_BASE_MEMBER(ApplyAdagradV2);
ApplyAdagradV2() : BaseOperator(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
};
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyAdagradV2Ptr = std::shared_ptr<ApplyAdagradV2>;
} // namespace ops
} // namespace mindspore

View File

@ -23,6 +23,8 @@
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -91,6 +93,7 @@ TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vect
}
} // namespace
MIND_API_BASE_IMPL(ApplyAdamWithAmsgrad, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -19,24 +19,22 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdamWithAmsgrad = "ApplyAdamWithAmsgrad";
class ApplyAdamWithAmsgrad : public PrimitiveC {
class MIND_API ApplyAdamWithAmsgrad : public BaseOperator {
public:
ApplyAdamWithAmsgrad() : PrimitiveC(kNameApplyAdamWithAmsgrad) {
MIND_API_BASE_MEMBER(ApplyAdamWithAmsgrad);
ApplyAdamWithAmsgrad() : BaseOperator(kNameApplyAdamWithAmsgrad) {
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
}
~ApplyAdamWithAmsgrad() = default;
MS_DECLARE_PARENT(ApplyAdamWithAmsgrad, PrimitiveC);
};
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimApplyAdamWithAmsgradPtr = std::shared_ptr<ApplyAdamWithAmsgrad>;
} // namespace ops

View File

@ -21,6 +21,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -116,6 +117,7 @@ TuplePtr ApplyAddSignInferType(const PrimitivePtr &prim, const std::vector<Abstr
}
} // namespace
MIND_API_BASE_IMPL(ApplyAddSign, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -21,27 +21,23 @@
#include <memory>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAddSign = "ApplyAddSign";
class ApplyAddSign : public PrimitiveC {
class MIND_API ApplyAddSign : public BaseOperator {
public:
ApplyAddSign() : PrimitiveC(kNameApplyAddSign) {
MIND_API_BASE_MEMBER(ApplyAddSign);
ApplyAddSign() : BaseOperator(kNameApplyAddSign) {
InitIOName({"var", "m", "lr", "alpha", "sign_decay", "beta", "grad"}, {"var", "m"});
}
~ApplyAddSign() = default;
MS_DECLARE_PARENT(ApplyAddSign, PrimitiveC);
};
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyAddSignPtr = std::shared_ptr<ApplyAddSign>;
} // namespace ops
} // namespace mindspore

View File

@ -17,6 +17,7 @@
#include "ops/apply_centered_rms_prop.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -106,6 +107,8 @@ TypePtr ApplyCenteredRMSPropInferType(const PrimitivePtr &primitive, const std::
return var_dtype;
}
} // namespace
MIND_API_BASE_IMPL(ApplyCenteredRMSProp, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = ApplyCenteredRMSPropInferType(primitive, input_args);

View File

@ -22,25 +22,23 @@
#include <set>
#include <map>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyCenteredRMSProp = "ApplyCenteredRMSProp";
class ApplyCenteredRMSProp : public PrimitiveC {
class MIND_API ApplyCenteredRMSProp : public BaseOperator {
public:
ApplyCenteredRMSProp() : PrimitiveC(kNameApplyCenteredRMSProp) {
MIND_API_BASE_MEMBER(ApplyCenteredRMSProp);
ApplyCenteredRMSProp() : BaseOperator(kNameApplyCenteredRMSProp) {
InitIOName(
{"var", "mean_gradient", "mean_square", "moment", "grad", "learning_rate", "decay", "momentum", "epsilon"},
{"var", "mean_gradient", "mean_square", "moment"});
}
~ApplyCenteredRMSProp() = default;
MS_DECLARE_PARENT(ApplyCenteredRMSProp, PrimitiveC);
};
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyCenteredRMSPropPtr = std::shared_ptr<ApplyCenteredRMSProp>;
} // namespace ops
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -91,6 +92,8 @@ TypePtr ApplyFtrlInferType(const PrimitivePtr &prim, const std::vector<AbstractB
return var_type;
}
} // namespace
MIND_API_BASE_IMPL(ApplyFtrl, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -22,24 +22,21 @@
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyFtrl = "ApplyFtrl";
class ApplyFtrl : public PrimitiveC {
class MIND_API ApplyFtrl : public BaseOperator {
public:
ApplyFtrl() : PrimitiveC(kNameApplyFtrl) {
MIND_API_BASE_MEMBER(ApplyFtrl);
ApplyFtrl() : BaseOperator(kNameApplyFtrl) {
InitIOName({"var", "accum", "linear", "grad", "lr", "l1", "l2", "lr_power"}, {"var"});
}
~ApplyFtrl() = default;
MS_DECLARE_PARENT(ApplyFtrl, PrimitiveC);
};
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyFtrlPtr = std::shared_ptr<ApplyFtrl>;
} // namespace ops
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -72,6 +73,7 @@ TypePtr ApplyGradientDescentInferType(const PrimitivePtr &prim, const std::vecto
}
} // namespace
MIND_API_BASE_IMPL(ApplyGradientDescent, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -22,24 +22,20 @@
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyGradientDescent = "ApplyGradientDescent";
class ApplyGradientDescent : public PrimitiveC {
class MIND_API ApplyGradientDescent : public BaseOperator {
public:
ApplyGradientDescent() : PrimitiveC(kNameApplyGradientDescent) { InitIOName({"var", "alpha", "delta"}, {"var"}); }
~ApplyGradientDescent() = default;
MS_DECLARE_PARENT(ApplyGradientDescent, PrimitiveC);
MIND_API_BASE_MEMBER(ApplyGradientDescent);
ApplyGradientDescent() : BaseOperator(kNameApplyGradientDescent) { InitIOName({"var", "alpha", "delta"}, {"var"}); }
};
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimApplyGradientDescentPtr = std::shared_ptr<ApplyGradientDescent>;
} // namespace ops

View File

@ -22,6 +22,8 @@
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -81,6 +83,7 @@ TuplePtr ApplyKerasMomentumInferType(const PrimitivePtr &prim, const std::vector
}
} // namespace
MIND_API_BASE_IMPL(ApplyKerasMomentum, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -22,24 +22,22 @@
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyKerasMomentum = "ApplyKerasMomentum";
class MS_CORE_API ApplyKerasMomentum : public PrimitiveC {
class MIND_API ApplyKerasMomentum : public BaseOperator {
public:
ApplyKerasMomentum() : PrimitiveC(kNameApplyKerasMomentum) {
MIND_API_BASE_MEMBER(ApplyKerasMomentum);
ApplyKerasMomentum() : BaseOperator(kNameApplyKerasMomentum) {
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
}
~ApplyKerasMomentum() = default;
MS_DECLARE_PARENT(ApplyKerasMomentum, PrimitiveC);
};
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimApplyKerasMomentumPtr = std::shared_ptr<ApplyKerasMomentum>;
} // namespace ops

View File

@ -20,6 +20,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -30,15 +31,15 @@ void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const
}
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
(void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov));
(void)this->AddAttr(kUseNesterov, api::MakeValue(use_nesterov));
}
void ApplyMomentum::set_use_locking(const bool use_locking) {
(void)this->AddAttr(kUseLocking, MakeValue(use_locking));
(void)this->AddAttr(kUseLocking, api::MakeValue(use_locking));
}
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
(void)this->AddAttr(kGradientScale, MakeValue(gradient_scale));
(void)this->AddAttr(kGradientScale, api::MakeValue(gradient_scale));
}
bool ApplyMomentum::get_use_nesterov() const {
@ -102,6 +103,8 @@ TypePtr ApplyMomentumInferType(const PrimitivePtr &primitive, const std::vector<
return v_tensor_type;
}
} // namespace
MIND_API_BASE_IMPL(ApplyMomentum, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = ApplyMomentumInferType(primitive, input_args);
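
The attribute setters shown here change mechanically: values are now produced by api::MakeValue so they live in the MindAPI value hierarchy, while the getters keep using GetAttr/GetValue. A hedged before/after sketch for a hypothetical boolean attribute (SomeOp and kUseFancyMode are invented for illustration):

// Old core-internal style (the kind of line this commit removes):
//   (void)this->AddAttr(kUseFancyMode, MakeValue(use_fancy_mode));
// New MindAPI style (the kind of line this commit adds):
void SomeOp::set_use_fancy_mode(const bool use_fancy_mode) {
  (void)this->AddAttr(kUseFancyMode, api::MakeValue(use_fancy_mode));  // builds an api::Value instead of a core ValuePtr
}
bool SomeOp::get_use_fancy_mode() const {
  return GetValue<bool>(GetAttr(kUseFancyMode));  // read path is unchanged
}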

View File

@ -22,24 +22,21 @@
#include <set>
#include <map>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyMomentum = "ApplyMomentum";
/// \brief Optimizer that implements the Momentum algorithm.
/// Refer to Python API @ref mindspore.ops.ApplyMomentum for more details.
class MS_CORE_API ApplyMomentum : public PrimitiveC {
class MIND_API ApplyMomentum : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ApplyMomentum);
/// \brief Constructor.
ApplyMomentum() : PrimitiveC(kNameApplyMomentum) {
ApplyMomentum() : BaseOperator(kNameApplyMomentum) {
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
}
/// \brief Destructor.
~ApplyMomentum() = default;
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ApplyMomentum for the inputs.
void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0);
/// \brief Set use_nesterov.
@ -61,8 +58,8 @@ class MS_CORE_API ApplyMomentum : public PrimitiveC {
/// \return gradient_scale.
float get_gradient_scale() const;
};
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>;
} // namespace ops
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -108,6 +109,7 @@ TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<Ab
}
} // namespace
MIND_API_BASE_IMPL(ApplyPowerSign, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -19,23 +19,21 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
class ApplyPowerSign : public PrimitiveC {
class MIND_API ApplyPowerSign : public BaseOperator {
public:
ApplyPowerSign() : PrimitiveC(kNameApplyPowerSign) {
MIND_API_BASE_MEMBER(ApplyPowerSign);
ApplyPowerSign() : BaseOperator(kNameApplyPowerSign) {
InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
}
~ApplyPowerSign() = default;
MS_DECLARE_PARENT(ApplyPowerSign, PrimitiveC);
};
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
} // namespace ops
} // namespace mindspore

View File

@ -22,6 +22,8 @@
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -99,6 +101,7 @@ TuplePtr ApplyProximalAdagradInferType(const PrimitivePtr &primitive, const std:
}
} // namespace
MIND_API_BASE_IMPL(ApplyProximalAdagrad, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -22,24 +22,22 @@
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyProximalAdagrad = "ApplyProximalAdagrad";
class ApplyProximalAdagrad : public PrimitiveC {
class MIND_API ApplyProximalAdagrad : public BaseOperator {
public:
ApplyProximalAdagrad() : PrimitiveC(kNameApplyProximalAdagrad) {
MIND_API_BASE_MEMBER(ApplyProximalAdagrad);
ApplyProximalAdagrad() : BaseOperator(kNameApplyProximalAdagrad) {
InitIOName({"var", "accum", "lr", "l1", "l2", "grad"}, {"var", "accum"});
}
~ApplyProximalAdagrad() = default;
MS_DECLARE_PARENT(ApplyProximalAdagrad, PrimitiveC);
};
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApplyProximalAdagradPtr = std::shared_ptr<ApplyProximalAdagrad>;
} // namespace ops

View File

@ -25,6 +25,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -98,6 +99,7 @@ TypePtr ApplyProximalGradientDescentInferType(const PrimitivePtr &prim,
}
} // namespace
MIND_API_BASE_IMPL(ApplyProximalGradientDescent, PrimitiveC, BaseOperator);
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
const int64_t input_num = 5;

View File

@ -19,23 +19,22 @@
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyProximalGradientDescent = "ApplyProximalGradientDescent";
class ApplyProximalGradientDescent : public PrimitiveC {
class MIND_API ApplyProximalGradientDescent : public BaseOperator {
public:
ApplyProximalGradientDescent() : PrimitiveC(kNameApplyProximalGradientDescent) {
MIND_API_BASE_MEMBER(ApplyProximalGradientDescent);
ApplyProximalGradientDescent() : BaseOperator(kNameApplyProximalGradientDescent) {
InitIOName({"var", "alpha", "l1", "l2", "delta"}, {"var"});
}
~ApplyProximalGradientDescent() = default;
MS_DECLARE_PARENT(ApplyProximalGradientDescent, PrimitiveC);
};
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -18,6 +18,9 @@
#include <set>
#include <map>
#include <string>
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -59,6 +62,8 @@ TypePtr ApproximateEqualInferType(const PrimitivePtr &prim, const std::vector<Ab
return y_dtype;
}
} // namespace
MIND_API_BASE_IMPL(ApproximateEqual, PrimitiveC, BaseOperator);
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -19,21 +19,18 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace ops {
class ApproximateEqual : public PrimitiveC {
class MIND_API ApproximateEqual : public BaseOperator {
public:
ApproximateEqual() : PrimitiveC(prim::kPrimApproximateEqual->name()) {}
~ApproximateEqual() = default;
MS_DECLARE_PARENT(ApproximateEqual, PrimitiveC);
MIND_API_BASE_MEMBER(ApproximateEqual);
ApproximateEqual() : BaseOperator("ApproximateEqual") {}
void Init() {}
};
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimApproximateEqualPtr = std::shared_ptr<ApproximateEqual>;
} // namespace ops
} // namespace mindspore

View File

@ -15,6 +15,10 @@
*/
#include "ops/arg_max.h"
#include "mindapi/ir/type.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -23,15 +27,17 @@ void ArgMax::Init(const int64_t axis, const TypeId output_type) {
set_output_type(output_type);
}
void ArgMax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
void ArgMax::set_output_type(const TypeId output_type) { (void)this->AddAttr(kOutputType, TypeIdToType(output_type)); }
void ArgMax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
void ArgMax::set_output_type(const TypeId output_type) {
(void)this->AddAttr(kOutputType, api::Type::GetType(output_type));
}
int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
TypeId ArgMax::get_output_type() const {
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
auto type_ptr = GetAttr(kOutputType)->cast<api::TensorTypePtr>()->element();
return type_ptr->type_id();
}
MIND_API_BASE_IMPL(ArgMax, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
} // namespace ops
} // namespace mindspore
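
With ArgMax migrated, a front end such as the lite adapter can build and configure the op entirely through the public MindAPI surface and only drop down to the core primitive when it has to. A minimal usage sketch, assuming the defaults shown in the header; the function name and values are chosen for illustration:

#include "ops/arg_max.h"

void BuildArgMax() {
  mindspore::ops::ArgMax argmax;                           // backed by a PrimitiveC created in the BaseOperator ctor
  argmax.Init(/*axis=*/-1, mindspore::kNumberTypeInt32);   // stores kAxis via api::MakeValue, kOutputType via api::Type::GetType
  auto axis = argmax.get_axis();                           // reads the attribute back through GetAttr/GetValue
  auto prim = argmax.GetPrim();                            // bridge back to the core PrimitiveC for legacy call sites
  (void)axis;
  (void)prim;
}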

View File

@ -20,24 +20,21 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "mindapi/base/type_id.h"
namespace mindspore {
namespace ops {
constexpr auto kNameArgMax = "Argmax";
/// \brief Returns the indices of the maximum value of a tensor across the axis.
/// Refer to Python API @ref mindspore.ops.Argmax for more details.
class MS_CORE_API ArgMax : public PrimitiveC {
class MIND_API ArgMax : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ArgMax);
/// \brief Constructor.
ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); }
explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
/// \brief Destructor.
~ArgMax() = default;
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
ArgMax() : BaseOperator(kNameArgMax) { InitIOName({"x"}, {"output"}); }
explicit ArgMax(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmax for the inputs.
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
/// \brief Set axis.
@ -54,8 +51,8 @@ class MS_CORE_API ArgMax : public PrimitiveC {
/// \return output_type.
TypeId get_output_type() const;
};
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -16,21 +16,28 @@
#include <set>
#include "ops/arg_min.h"
#include "mindapi/ir/type.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(ArgMin, PrimitiveC, BaseOperator);
void ArgMin::Init(const int64_t axis, const TypeId output_type) {
set_axis(axis);
set_output_type(output_type);
}
void ArgMin::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
void ArgMin::set_output_type(const TypeId output_type) { (void)this->AddAttr(kOutputType, TypeIdToType(output_type)); }
void ArgMin::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
void ArgMin::set_output_type(const TypeId output_type) {
(void)this->AddAttr(kOutputType, api::Type::GetType(output_type));
}
int64_t ArgMin::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
TypeId ArgMin::get_output_type() const {
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
auto type_ptr = GetAttr(kOutputType)->cast<api::TensorTypePtr>()->element();
return type_ptr->type_id();
}

View File

@ -20,24 +20,21 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "mindapi/base/type_id.h"
namespace mindspore {
namespace ops {
constexpr auto kNameArgMin = "ArgMin";
/// \brief Returns the indices of the minimum value of a tensor across the axis.
/// Refer to Python API @ref mindspore.ops.Argmin for more details.
class MS_CORE_API ArgMin : public PrimitiveC {
class MIND_API ArgMin : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ArgMin);
/// \brief Constructor.
ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); }
explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
/// \brief Destructor.
~ArgMin() = default;
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
ArgMin() : BaseOperator(kNameArgMin) { InitIOName({"x"}, {"output"}); }
explicit ArgMin(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmin for the inputs.
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
/// \brief Set axis.
@ -54,8 +51,8 @@ class MS_CORE_API ArgMin : public PrimitiveC {
/// \return output_type.
TypeId get_output_type() const;
};
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimArgMin = std::shared_ptr<ArgMin>;
} // namespace ops
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -42,6 +43,7 @@ TypePtr AsinInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
}
} // namespace
MIND_API_BASE_IMPL(Asin, PrimitiveC, BaseOperator);
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -22,30 +22,26 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAsin = "Asin";
/// \brief Computes arcsine of input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.Asin for more details.
class MS_CORE_API Asin : public PrimitiveC {
class MIND_API Asin : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Asin);
/// \brief Constructor.
Asin() : PrimitiveC(kNameAsin) { InitIOName({"x"}, {"y"}); }
/// \brief Destructor.
~Asin() = default;
MS_DECLARE_PARENT(Asin, PrimitiveC);
Asin() : BaseOperator(kNameAsin) { InitIOName({"x"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Asin for the inputs.
void Init() const {}
};
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAsinPtr = std::shared_ptr<Asin>;
} // namespace ops

View File

@ -15,6 +15,11 @@
*/
#include "ops/asinh.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -42,6 +47,7 @@ TypePtr AsinhInferType(const PrimitivePtr &primitive, const std::vector<Abstract
}
} // namespace
MIND_API_BASE_IMPL(Asinh, PrimitiveC, BaseOperator);
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -22,29 +22,24 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAsinh = "Asinh";
/// \brief Computes arcsinh of input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.Asinh for more details.
class MS_CORE_API Asinh : public PrimitiveC {
class MIND_API Asinh : public BaseOperator {
public:
/// \brief Constructor.
Asinh() : PrimitiveC(kNameAsinh) { InitIOName({"x"}, {"y"}); }
/// \brief Destructor.
~Asinh() = default;
MS_DECLARE_PARENT(Asinh, PrimitiveC);
MIND_API_BASE_MEMBER(Asinh);
Asinh() : BaseOperator(kNameAsinh) { InitIOName({"x"}, {"y"}); }
void Init() {}
};
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAsinhPtr = std::shared_ptr<Asinh>;
} // namespace ops

View File

@ -21,13 +21,16 @@
#include <memory>
#include "ops/assert.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(Assert, PrimitiveC, BaseOperator);
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, MakeValue(summarize)); }
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, api::MakeValue(summarize)); }
int64_t Assert::get_summarize() const {
auto value_ptr = GetAttr(kSummarize);

View File

@ -18,23 +18,18 @@
#define MINDSPORE_CORE_OPS_ASSERT_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssert = "Assert";
/// \brief Assert defines the Assert operator prototype for lite.
class MS_CORE_API Assert : public PrimitiveC {
class MIND_API Assert : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Assert);
/// \brief Constructor.
Assert() : PrimitiveC(kNameAssert) {}
/// \brief Destructor.
~Assert() = default;
MS_DECLARE_PARENT(Assert, PrimitiveC);
Assert() : BaseOperator(kNameAssert) {}
/// \brief Method to init the op's attributes.
///
@ -52,8 +47,8 @@ class MS_CORE_API Assert : public PrimitiveC {
int64_t get_summarize() const;
};
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -23,9 +23,12 @@
#include "ops/assign.h"
#include "ops/op_utils.h"
#include "ir/dtype/ref.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(Assign, PrimitiveC, BaseOperator);
abstract::ShapePtr AssignInferShape(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();

View File

@ -19,21 +19,18 @@
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssign = "Assign";
/// \brief Assigns a value to a Parameter. Refer to Python API @ref mindspore.ops.Assign for more details.
class MS_CORE_API Assign : public PrimitiveC {
class MIND_API Assign : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Assign);
/// \brief Constructor.
Assign() : PrimitiveC(kNameAssign) { InitIOName({"ref", "value"}, {"output"}); }
/// \brief Destructor.
~Assign() = default;
MS_DECLARE_PARENT(Assign, PrimitiveC);
Assign() : BaseOperator(kNameAssign) { InitIOName({"ref", "value"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Assign for the inputs.
void Init() const {}
};

View File

@ -19,6 +19,7 @@
#include "ops/assign_add.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -39,6 +40,8 @@ TypePtr AssignAddInferType(const PrimitivePtr &primitive, const std::vector<Abst
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
}
} // namespace
MIND_API_BASE_IMPL(AssignAdd, PrimitiveC, BaseOperator);
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -19,27 +19,24 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssignAdd = "AssignAdd";
/// \brief Updates a Parameter by adding a value to it.
/// Refer to Python API @ref mindspore.ops.AssignAdd for more details.
class MS_CORE_API AssignAdd : public PrimitiveC {
class MIND_API AssignAdd : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AssignAdd);
/// \brief Constructor.
AssignAdd() : PrimitiveC(kNameAssignAdd) { InitIOName({"ref", "value"}, {"output"}); }
/// \brief Destructor.
~AssignAdd() = default;
MS_DECLARE_PARENT(AssignAdd, PrimitiveC);
AssignAdd() : BaseOperator(kNameAssignAdd) { InitIOName({"ref", "value"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AssignAdd for the inputs.
void Init() const {}
};
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimAssignAddPtr = std::shared_ptr<AssignAdd>;
} // namespace ops
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <string>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -41,6 +42,7 @@ TypePtr AssignSubInferType(const PrimitivePtr &primitive, const std::vector<Abst
}
} // namespace
MIND_API_BASE_IMPL(AssignSub, PrimitiveC, BaseOperator);
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -19,22 +19,20 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssignSub = "AssignSub";
class AssignSub : public PrimitiveC {
class MIND_API AssignSub : public BaseOperator {
public:
AssignSub() : PrimitiveC(kNameAssignSub) { InitIOName({"val", "value"}, {"val"}); }
~AssignSub() = default;
MS_DECLARE_PARENT(AssignSub, PrimitiveC);
MIND_API_BASE_MEMBER(AssignSub);
AssignSub() : BaseOperator(kNameAssignSub) { InitIOName({"val", "value"}, {"val"}); }
void Init() {}
};
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimAssignSubPtr = std::shared_ptr<AssignSub>;
} // namespace ops
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -52,6 +53,8 @@ TypePtr AtanInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
return x_type;
}
} // namespace
MIND_API_BASE_IMPL(Atan, PrimitiveC, BaseOperator);
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -20,27 +20,24 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAtan = "Atan";
/// \brief Computes the trigonometric inverse tangent of the input element-wise.
/// Refer to Python API @ref mindspore.ops.Atan for more details.
class MS_CORE_API Atan : public PrimitiveC {
class MIND_API Atan : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Atan);
/// \brief Constructor.
Atan() : PrimitiveC(kNameAtan) { InitIOName({"x"}, {"output"}); }
/// \brief Destructor.
~Atan() = default;
MS_DECLARE_PARENT(Atan, PrimitiveC);
Atan() : BaseOperator(kNameAtan) {}
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Atan for the inputs.
void Init() const {}
};
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -58,6 +59,8 @@ TypePtr AtanhInferType(const PrimitivePtr &primitive, const std::vector<Abstract
return x_type;
}
} // namespace
MIND_API_BASE_IMPL(Atanh, PrimitiveC, BaseOperator);
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto type = AtanhInferType(primitive, input_args);

View File

@ -20,22 +20,20 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAtanh = "Atanh";
class Atanh : public PrimitiveC {
class MIND_API Atanh : public BaseOperator {
public:
Atanh() : PrimitiveC(kNameAtanh) { InitIOName({"x"}, {"output"}); }
~Atanh() = default;
MS_DECLARE_PARENT(Atanh, PrimitiveC);
MIND_API_BASE_MEMBER(Atanh);
Atanh() : BaseOperator(kNameAtanh) { InitIOName({"x"}, {"output"}); }
void Init() {}
};
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimAtanhPtr = std::shared_ptr<Atanh>;
} // namespace ops

View File

@ -16,7 +16,10 @@
*/
#include "ops/attention.h"
#include "ops/primitive_c.h"
#include "mindapi/src/helper.h"
namespace mindspore::ops {
MIND_API_BASE_IMPL(Attention, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
} // namespace mindspore::ops

View File

@ -19,25 +19,23 @@
#include <vector>
#include <string>
#include <memory>
#include "utils/check_convert_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAttention = "Attention";
/// \brief MultiHead-Attention op in MindIR.
class MS_CORE_API Attention : public PrimitiveC {
class MIND_API Attention : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Attention);
/// \brief Constructor.
Attention() : PrimitiveC(kNameAttention) {
Attention() : BaseOperator(kNameAttention) {
InitIOName(
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"},
{"output"});
}
/// \brief Destructor.
~Attention() override = default;
MS_DECLARE_PARENT(Attention, PrimitiveC);
/// \brief Initialize Attention op.
void Init() const {}
};

View File

@ -23,24 +23,28 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(AudioSpectrogram, PrimitiveC, BaseOperator);
void AudioSpectrogram::set_window_size(const int64_t window_size) {
(void)this->AddAttr(kWindowSize, MakeValue(window_size));
(void)this->AddAttr(kWindowSize, api::MakeValue(window_size));
}
int64_t AudioSpectrogram::get_window_size() const {
auto value_ptr = GetAttr(kWindowSize);
return GetValue<int64_t>(value_ptr);
}
void AudioSpectrogram::set_stride(const int64_t stride) { (void)this->AddAttr(kStride, MakeValue(stride)); }
void AudioSpectrogram::set_stride(const int64_t stride) { (void)this->AddAttr(kStride, api::MakeValue(stride)); }
int64_t AudioSpectrogram::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<int64_t>(value_ptr);
}
void AudioSpectrogram::set_mag_square(const bool mag_square) { (void)this->AddAttr(kMagSquare, MakeValue(mag_square)); }
void AudioSpectrogram::set_mag_square(const bool mag_square) {
(void)this->AddAttr(kMagSquare, api::MakeValue(mag_square));
}
bool AudioSpectrogram::get_mag_square() const {
auto value_ptr = GetAttr(kMagSquare);
return GetValue<bool>(value_ptr);

View File

@ -20,23 +20,18 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
/// \brief AudioSpectrogram defines the AudioSpectrogram operator prototype.
class MS_CORE_API AudioSpectrogram : public PrimitiveC {
class MIND_API AudioSpectrogram : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AudioSpectrogram);
/// \brief Constructor.
AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {}
/// \brief Destructor.
~AudioSpectrogram() = default;
MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
AudioSpectrogram() : BaseOperator(kNameAudioSpectrogram) {}
/// \brief Method to init the op's attributes.
///
@ -75,8 +70,8 @@ class MS_CORE_API AudioSpectrogram : public PrimitiveC {
/// \return a boolean value.
bool get_mag_square() const;
};
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -23,35 +23,37 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
int64_t swi = pad_mode;
(void)this->AddAttr(kPadMode, MakeValue(swi));
(void)this->AddAttr(kPadMode, api::MakeValue(swi));
}
PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
(void)this->AddAttr(kKernelSize,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
(void)this->AddAttr(
kKernelSize, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
}
std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
(void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
(void)this->AddAttr(kStrides,
api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
}
std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
void AvgPool::set_format(const Format &format) {
int64_t f = format;
(void)this->AddAttr(kFormat, MakeValue(f));
(void)this->AddAttr(kFormat, api::MakeValue(f));
}
Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, api::MakeValue(pad)); }
std::vector<int64_t> AvgPool::get_pad() const {
auto value_ptr = GetAttr(kPad);
@ -60,7 +62,7 @@ std::vector<int64_t> AvgPool::get_pad() const {
void AvgPool::set_round_mode(const RoundMode &round_mode) {
int64_t swi = round_mode;
(void)this->AddAttr(kRoundMode, MakeValue(swi));
(void)this->AddAttr(kRoundMode, api::MakeValue(swi));
}
RoundMode AvgPool::get_round_mode() const {
@ -78,6 +80,7 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
this->set_round_mode(round_mode);
}
MIND_API_BASE_IMPL(AvgPool, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
} // namespace ops
} // namespace mindspore
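
Enum-valued attributes (pad mode, format, round mode) are still persisted as int64_t; only the value factory switches to api::MakeValue. A short usage sketch, assuming the trailing Init parameters keep the defaults declared in the header (the kernel and stride values are illustrative):

#include "ops/avg_pool.h"

void ConfigureAvgPool() {
  mindspore::ops::AvgPool pool;
  pool.Init({1, 1, 2, 2}, {1, 1, 2, 2});   // pad_mode, format, pad and round_mode fall back to their defaults
  pool.set_pad({0, 0, 0, 0});              // stored as an int64_t vector via api::MakeValue
  auto mode = pool.get_pad_mode();         // enum attrs are written as int64_t and cast back to PadMode on read
  (void)mode;
}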

View File

@ -21,22 +21,20 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "mindapi/base/format.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAvgPool = "AvgPool";
/// \brief Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool for more details.
class MS_CORE_API AvgPool : public PrimitiveC {
class MIND_API AvgPool : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AvgPool);
/// \brief Constructor.
AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
/// \brief Destructor.
~AvgPool() = default;
MS_DECLARE_PARENT(AvgPool, PrimitiveC);
AvgPool() : BaseOperator(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
explicit AvgPool(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AvgPool for the inputs.
void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1},
const PadMode &pad_mode = VALID, const Format &format = NCHW,
@ -80,8 +78,8 @@ class MS_CORE_API AvgPool : public PrimitiveC {
RoundMode get_round_mode() const;
};
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -22,6 +22,7 @@
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -180,6 +181,7 @@ TypePtr AvgPool3DInferType(const PrimitivePtr &primitive, const std::vector<Abst
}
} // namespace
MIND_API_BASE_IMPL(AvgPool3D, PrimitiveC, BaseOperator);
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(AvgPool3DInferShape(primitive, input_args), AvgPool3DInferType(primitive, input_args));

View File

@ -21,24 +21,21 @@
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
/// \brief 3D Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool3D for more details.
class MS_CORE_API AvgPool3D : public PrimitiveC {
class MIND_API AvgPool3D : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AvgPool3D);
/// \brief Constructor.
AvgPool3D() : PrimitiveC(prim::kPrimAvgPool3D->name()) { InitIOName({"input"}, {"output"}); }
/// \brief Destructor.
~AvgPool3D() = default;
MS_DECLARE_PARENT(AvgPool3D, PrimitiveC);
AvgPool3D() : BaseOperator("AvgPool3D") { InitIOName({"input"}, {"output"}); }
};
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -15,13 +15,18 @@
*/
#include "ops/base_operator.h"
#include "ops/primitive_c.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(BaseOperator, PrimitiveC, api::Primitive);
BaseOperator::BaseOperator(const std::string &name) : api::Primitive(std::make_shared<PrimitiveC>(name)) {}
PrimitiveCPtr BaseOperator::GetPrim() {
PrimitiveCPtr res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
return res;
}
void BaseOperator::InitIOName(const std::vector<std::string> &inputs_name,
const std::vector<std::string> &outputs_name) {
(void)AddAttr("input_names", api::MakeValue(inputs_name));
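
BaseOperator itself is the bridge between the two APIs: it is an api::Primitive whose impl is a core PrimitiveC, and GetPrim() dynamic-casts that impl back for code paths that still need the core object. A hedged sketch of how adapter code might cross the boundary (the helper name is illustrative):

#include "ops/base_operator.h"

// Recover the core primitive behind any new-style operator for legacy code paths.
mindspore::ops::PrimitiveCPtr ToCorePrimitive(mindspore::ops::BaseOperator &op) {
  return op.GetPrim();  // dynamic_pointer_cast of the wrapped api::Primitive impl to ops::PrimitiveC
}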

View File

@ -17,26 +17,36 @@
#ifndef MINDSPORE_CORE_OPS_BASE_OPERATOR_
#define MINDSPORE_CORE_OPS_BASE_OPERATOR_
#include <string>
#include <memory>
#include <string>
#include <vector>
#include "mindapi/ir/primitive.h"
namespace mindspore {
namespace abstract {
class AnalysisEngine;
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;
class AbstractBase;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
} // namespace abstract
} // namespace mindspore
namespace mindspore {
class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;
} // namespace mindspore
namespace mindspore {
namespace ops {
class BaseOperator : public api::Primitive {
class PrimitiveC;
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
class MIND_API BaseOperator : public api::Primitive {
public:
MIND_API_BASE_MEMBER(BaseOperator);
explicit BaseOperator(const std::string &name);
~BaseOperator() = default;
PrimitiveCPtr GetPrim();
protected:
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);

View File

@ -23,6 +23,7 @@
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -134,14 +135,15 @@ TypePtr BatchMatmulInferType(const PrimitivePtr &prim, const std::vector<Abstrac
}
} // namespace
MIND_API_BASE_IMPL(BatchMatmul, PrimitiveC, BaseOperator);
void BatchMatmul::Init(bool transpose_a, bool transpose_b) {
set_transpose_a(transpose_a);
set_transpose_b(transpose_b);
}
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, api::MakeValue(transpose_a)); }
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, api::MakeValue(transpose_b)); }
bool BatchMatmul::get_transpose_a() const {
auto value_ptr = GetAttr(kTransposeA);

View File

@ -18,21 +18,18 @@
#define MINDSPORE_CORE_OPS_BATCH_MATMUL_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
/// \brief Computes matrix multiplication between two tensors by batch.
/// Refer to Python API @ref mindspore.ops.BatchMatmul for more details.
class MS_CORE_API BatchMatmul : public PrimitiveC {
class MIND_API BatchMatmul : public BaseOperator {
public:
MIND_API_BASE_MEMBER(BatchMatmul);
/// \brief Constructor.
BatchMatmul() : PrimitiveC(prim::kPrimBatchMatMul->name()) { InitIOName({"x1", "x2"}, {"output"}); }
/// \brief Destructor.
~BatchMatmul() = default;
MS_DECLARE_PARENT(BatchMatmul, PrimitiveC);
BatchMatmul() : BaseOperator("BatchMatMul") { InitIOName({"x1", "x2"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.BatchMatmul for the inputs.
void Init(bool transpose_a = false, bool transpose_b = false);
/// \brief Set transpose_a.
@ -48,8 +45,8 @@ class MS_CORE_API BatchMatmul : public PrimitiveC {
/// \return transpose_b.
bool get_transpose_b() const;
};
AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
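
Taken together, the new core ops API lets callers such as the lite adapter construct, configure and query BatchMatmul without including any core-only headers. A minimal usage sketch with illustrative flag values:

#include "ops/batch_matmul.h"

void BuildBatchMatMul() {
  mindspore::ops::BatchMatmul bmm;                        // registers under the primitive name "BatchMatMul"
  bmm.Init(/*transpose_a=*/false, /*transpose_b=*/true);  // writes kTransposeA/kTransposeB via api::MakeValue
  bool tb = bmm.get_transpose_b();                        // expected to read back true
  (void)tb;
}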

Some files were not shown because too many files have changed in this diff.