forked from mindspore-Ecosystem/mindspore
!31666 [MS][LITE] new core ops api and lite adapter new api
Merge pull request !31666 from luoyuan/core2
This commit is contained in:
commit
0b6f330d7e
|
@ -24,6 +24,7 @@
|
||||||
#include "utils/shape_utils.h"
|
#include "utils/shape_utils.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/anf_utils.h"
|
#include "utils/anf_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#define USE_DEPRECATED_API
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
|
@ -69,8 +69,8 @@ class RegisterStandardPrimitiveEvalHelper {
|
||||||
static auto helper_##name = \
|
static auto helper_##name = \
|
||||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \
|
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \
|
||||||
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
|
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
|
||||||
auto out = std::make_shared<name>(); \
|
name out; \
|
||||||
return out; \
|
return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl()); \
|
||||||
} \
|
} \
|
||||||
ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
|
ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
|
|
|
@ -114,5 +114,10 @@ enum PaddingMode : int64_t {
|
||||||
SYMMETRIC = 2,
|
SYMMETRIC = 2,
|
||||||
MODE_RESERVED = 3,
|
MODE_RESERVED = 3,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum PoolMode : int64_t {
|
||||||
|
MAX_POOLING = 0,
|
||||||
|
MEAN_POOLING = 1,
|
||||||
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_
|
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_
|
||||||
|
|
|
@ -153,5 +153,7 @@ class MIND_API AbstractTuple : public AbstractSequence {
|
||||||
/// \param[in] elements A list of abstracts.
|
/// \param[in] elements A list of abstracts.
|
||||||
explicit AbstractTuple(const AbstractBasePtrList &elements);
|
explicit AbstractTuple(const AbstractBasePtrList &elements);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using AbstractTuplePtr = SharedPtr<AbstractTuple>;
|
||||||
} // namespace mindspore::api
|
} // namespace mindspore::api
|
||||||
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
||||||
|
|
|
@ -47,5 +47,9 @@ using FuncGraphPtr = SharedPtr<FuncGraph>;
|
||||||
|
|
||||||
class FuncGraphManager;
|
class FuncGraphManager;
|
||||||
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
|
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
|
||||||
|
|
||||||
|
class CNode;
|
||||||
|
using CNodePtr = SharedPtr<CNode>;
|
||||||
|
using CNodePtrList = std::vector<CNodePtr>;
|
||||||
} // namespace mindspore::api
|
} // namespace mindspore::api
|
||||||
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
|
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -49,6 +50,7 @@ TypePtr LayerNormBetaGammaBackpropInferType(const PrimitivePtr &prim, const std:
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(LayerNormBetaGammaBackprop, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -20,23 +20,22 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class MS_CORE_API LayerNormBetaGammaBackprop : public PrimitiveC {
|
class MIND_API LayerNormBetaGammaBackprop : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
LayerNormBetaGammaBackprop() : PrimitiveC(prim::kPrimLayerNormBetaGammaBackprop->name()) {}
|
MIND_API_BASE_MEMBER(LayerNormBetaGammaBackprop);
|
||||||
~LayerNormBetaGammaBackprop() = default;
|
LayerNormBetaGammaBackprop() : BaseOperator("LayerNormBetaGammaBackprop") {}
|
||||||
MS_DECLARE_PARENT(LayerNormBetaGammaBackprop, PrimitiveC);
|
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePtr &,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const PrimitivePtr &primitive,
|
||||||
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -44,6 +45,7 @@ TypePtr LayerNormXBackpropInferType(const PrimitivePtr &prim, const std::vector<
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(LayerNormXBackprop, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -20,23 +20,21 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class MS_CORE_API LayerNormXBackprop : public PrimitiveC {
|
class MIND_API LayerNormXBackprop : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
LayerNormXBackprop() : PrimitiveC(prim::kPrimLayerNormXBackprop->name()) {}
|
MIND_API_BASE_MEMBER(LayerNormXBackprop);
|
||||||
~LayerNormXBackprop() = default;
|
LayerNormXBackprop() : BaseOperator("LayerNormXBackprop") {}
|
||||||
MS_DECLARE_PARENT(LayerNormXBackprop, PrimitiveC);
|
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -142,6 +143,8 @@ ValuePtr AbsInferValue(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
||||||
return result_tensor;
|
return result_tensor;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Abs, PrimitiveC, BaseOperator);
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer, AbsInferValue, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer, AbsInferValue, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,21 +19,18 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
/// \brief Returns absolute value of a tensor element-wise.
|
/// \brief Returns absolute value of a tensor element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.Abs for more details.
|
/// Refer to Python API @ref mindspore.ops.Abs for more details.
|
||||||
class MS_CORE_API Abs : public PrimitiveC {
|
class MIND_API Abs : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Abs);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Abs() : PrimitiveC(prim::kPrimAbs->name()) { InitIOName({"input_x"}, {"output"}); }
|
Abs() : BaseOperator("Abs") { InitIOName({"input_x"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Abs() = default;
|
|
||||||
MS_DECLARE_PARENT(Abs, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Abs for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Abs for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
|
@ -22,6 +22,8 @@
|
||||||
|
|
||||||
#include "ops/accumulate_n_v2.h"
|
#include "ops/accumulate_n_v2.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -83,6 +85,7 @@ TypePtr AccumulateNV2InferType(const PrimitivePtr &prim, const std::vector<Abstr
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(AccumulateNV2, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,21 +19,19 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAccumulateNV2 = "AccumulateNV2";
|
constexpr auto kNameAccumulateNV2 = "AccumulateNV2";
|
||||||
class MS_CORE_API AccumulateNV2 : public PrimitiveC {
|
class MIND_API AccumulateNV2 : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
AccumulateNV2() : PrimitiveC(kNameAccumulateNV2) { InitIOName({"inputs"}, {"sum"}); }
|
MIND_API_BASE_MEMBER(AccumulateNV2);
|
||||||
~AccumulateNV2() = default;
|
AccumulateNV2() : BaseOperator(kNameAccumulateNV2) { InitIOName({"inputs"}, {"sum"}); }
|
||||||
MS_DECLARE_PARENT(AccumulateNV2, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using PrimAccumulateNV2Ptr = std::shared_ptr<AccumulateNV2>;
|
using PrimAccumulateNV2Ptr = std::shared_ptr<AccumulateNV2>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,6 +15,15 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/acos.h"
|
#include "ops/acos.h"
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -38,6 +47,7 @@ TypePtr ACosInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ACos, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,28 +22,23 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameACos = "ACos";
|
constexpr auto kNameACos = "ACos";
|
||||||
/// \brief Computes arccosine of input tensors element-wise.
|
/// \brief Computes arccosine of input tensors element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.ACos for more details.
|
/// Refer to Python API @ref mindspore.ops.ACos for more details.
|
||||||
class ACos : public PrimitiveC {
|
class MIND_API ACos : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ACos);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
ACos() : PrimitiveC(kNameACos) { InitIOName({"x"}, {"y"}); }
|
ACos() : BaseOperator(kNameACos) { InitIOName({"x"}, {"y"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~ACos() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(ACos, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimACosPtr = std::shared_ptr<ACos>;
|
using PrimACosPtr = std::shared_ptr<ACos>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -15,6 +15,10 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/acosh.h"
|
#include "ops/acosh.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -42,6 +46,7 @@ TypePtr AcoshInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Acosh, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,28 +22,23 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAcosh = "Acosh";
|
constexpr auto kNameAcosh = "Acosh";
|
||||||
/// \brief Computes arccosh of input tensors element-wise.
|
/// \brief Computes arccosh of input tensors element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.Acosh for more details.
|
/// Refer to Python API @ref mindspore.ops.Acosh for more details.
|
||||||
class Acosh : public PrimitiveC {
|
class MIND_API Acosh : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Acosh);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Acosh() : PrimitiveC(kNameAcosh) { InitIOName({"x"}, {"y"}); }
|
Acosh() : BaseOperator(kNameAcosh) { InitIOName({"x"}, {"y"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Acosh() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Acosh, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimAcoshPtr = std::shared_ptr<Acosh>;
|
using PrimAcoshPtr = std::shared_ptr<Acosh>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -79,14 +80,18 @@ abstract::TupleShapePtr AdamInferShape(const PrimitivePtr &primitive, const std:
|
||||||
std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
|
std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Adam, PrimitiveC, BaseOperator);
|
||||||
void Adam::Init(const bool use_locking, const bool use_nesterov) {
|
void Adam::Init(const bool use_locking, const bool use_nesterov) {
|
||||||
this->set_use_locking(use_locking);
|
this->set_use_locking(use_locking);
|
||||||
this->set_use_nesterov(use_nesterov);
|
this->set_use_nesterov(use_nesterov);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, api::MakeValue(use_locking)); }
|
||||||
|
|
||||||
void Adam::set_use_nesterov(const bool use_nesterov) { (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
void Adam::set_use_nesterov(const bool use_nesterov) {
|
||||||
|
(void)this->AddAttr(kUseNesterov, api::MakeValue(use_nesterov));
|
||||||
|
}
|
||||||
|
|
||||||
bool Adam::get_use_locking() const {
|
bool Adam::get_use_locking() const {
|
||||||
auto value_ptr = GetAttr(kUseLocking);
|
auto value_ptr = GetAttr(kUseLocking);
|
||||||
|
|
|
@ -20,22 +20,19 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAdam = "Adam";
|
constexpr auto kNameAdam = "Adam";
|
||||||
/// \brief Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
|
/// \brief Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
|
||||||
/// Refer to Python API @ref mindspore.ops.Adam for more details.
|
/// Refer to Python API @ref mindspore.ops.Adam for more details.
|
||||||
class MS_CORE_API Adam : public PrimitiveC {
|
class MIND_API Adam : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Adam);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Adam() : PrimitiveC(kNameAdam) {}
|
Adam() : BaseOperator(kNameAdam) {}
|
||||||
/// \brief Destructor.
|
|
||||||
~Adam() = default;
|
|
||||||
MS_DECLARE_PARENT(Adam, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Adam for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Adam for the inputs.
|
||||||
void Init(const bool use_locking = false, const bool use_nesterov = false);
|
void Init(const bool use_locking = false, const bool use_nesterov = false);
|
||||||
/// \brief Set use_locking.
|
/// \brief Set use_locking.
|
||||||
|
@ -51,8 +48,8 @@ class MS_CORE_API Adam : public PrimitiveC {
|
||||||
/// \return use_nesterov.
|
/// \return use_nesterov.
|
||||||
bool get_use_nesterov() const;
|
bool get_use_nesterov() const;
|
||||||
};
|
};
|
||||||
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimAdamPtr = std::shared_ptr<Adam>;
|
using kPrimAdamPtr = std::shared_ptr<Adam>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -21,9 +21,11 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(Add, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -20,28 +20,25 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAdd = prim::kAdd;
|
constexpr auto kNameAdd = "Add";
|
||||||
/// \brief Adds two input tensors element-wise. Refer to Python API @ref mindspore.ops.Add for more details.
|
/// \brief Adds two input tensors element-wise. Refer to Python API @ref mindspore.ops.Add for more details.
|
||||||
class MS_CORE_API Add : public PrimitiveC {
|
class MIND_API Add : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Add);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
|
Add() : BaseOperator(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
|
||||||
explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); }
|
explicit Add(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "y"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Add() = default;
|
|
||||||
MS_DECLARE_PARENT(Add, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Add for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Add for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -73,6 +74,8 @@ TypePtr AddcdivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
||||||
return input_data_type;
|
return input_data_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Addcdiv, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,23 +19,20 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/op_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAddcdiv = "Addcdiv";
|
constexpr auto kNameAddcdiv = "Addcdiv";
|
||||||
class Addcdiv : public PrimitiveC {
|
class MIND_API Addcdiv : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
Addcdiv() : PrimitiveC(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
MIND_API_BASE_MEMBER(Addcdiv);
|
||||||
~Addcdiv() = default;
|
Addcdiv() : BaseOperator(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
||||||
MS_DECLARE_PARENT(Addcdiv, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using PrimAddcdivPtr = std::shared_ptr<Addcdiv>;
|
using PrimAddcdivPtr = std::shared_ptr<Addcdiv>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -73,6 +74,8 @@ TypePtr AddcmulInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
||||||
return input_data_type;
|
return input_data_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Addcmul, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,23 +19,20 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/op_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAddcmul = "Addcmul";
|
constexpr auto kNameAddcmul = "Addcmul";
|
||||||
class Addcmul : public PrimitiveC {
|
class MIND_API Addcmul : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
Addcmul() : PrimitiveC(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
MIND_API_BASE_MEMBER(Addcmul);
|
||||||
~Addcmul() = default;
|
Addcmul() : BaseOperator(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
|
||||||
MS_DECLARE_PARENT(Addcmul, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using PrimAddcmulPtr = std::shared_ptr<Addcmul>;
|
using PrimAddcmulPtr = std::shared_ptr<Addcmul>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
#include "ops/adder.h"
|
#include "ops/adder.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(Adder, PrimitiveC, BaseOperator);
|
||||||
void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||||
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
|
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
|
||||||
const std::vector<int64_t> &dilation, const int64_t group, const Format &format) {
|
const std::vector<int64_t> &dilation, const int64_t group, const Format &format) {
|
||||||
|
@ -33,14 +35,16 @@ void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std:
|
||||||
set_format(format);
|
set_format(format);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_in_channel(const int64_t in_channel) { (void)this->AddAttr(kInChannel, MakeValue(in_channel)); }
|
void Adder::set_in_channel(const int64_t in_channel) { (void)this->AddAttr(kInChannel, api::MakeValue(in_channel)); }
|
||||||
|
|
||||||
int64_t Adder::get_in_channel() const {
|
int64_t Adder::get_in_channel() const {
|
||||||
auto value_ptr = GetAttr(kInChannel);
|
auto value_ptr = GetAttr(kInChannel);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_out_channel(const int64_t out_channel) { (void)this->AddAttr(kOutChannel, MakeValue(out_channel)); }
|
void Adder::set_out_channel(const int64_t out_channel) {
|
||||||
|
(void)this->AddAttr(kOutChannel, api::MakeValue(out_channel));
|
||||||
|
}
|
||||||
|
|
||||||
int64_t Adder::get_out_channel() const {
|
int64_t Adder::get_out_channel() const {
|
||||||
auto value_ptr = GetAttr(kOutChannel);
|
auto value_ptr = GetAttr(kOutChannel);
|
||||||
|
@ -48,7 +52,7 @@ int64_t Adder::get_out_channel() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||||
(void)this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
(void)this->AddAttr(kKernelSize, api::MakeValue(kernel_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Adder::get_kernel_size() const {
|
std::vector<int64_t> Adder::get_kernel_size() const {
|
||||||
|
@ -58,7 +62,7 @@ std::vector<int64_t> Adder::get_kernel_size() const {
|
||||||
|
|
||||||
void Adder::set_pad_mode(const PadMode &pad_mode) {
|
void Adder::set_pad_mode(const PadMode &pad_mode) {
|
||||||
int64_t swi = pad_mode;
|
int64_t swi = pad_mode;
|
||||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
(void)this->AddAttr(kPadMode, api::MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
PadMode Adder::get_pad_mode() const {
|
PadMode Adder::get_pad_mode() const {
|
||||||
|
@ -66,28 +70,32 @@ PadMode Adder::get_pad_mode() const {
|
||||||
return PadMode(GetValue<int64_t>(value_ptr));
|
return PadMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_stride(const std::vector<int64_t> &stride) { (void)this->AddAttr(kStride, MakeValue(stride)); }
|
void Adder::set_stride(const std::vector<int64_t> &stride) { (void)this->AddAttr(kStride, api::MakeValue(stride)); }
|
||||||
|
|
||||||
std::vector<int64_t> Adder::get_stride() const {
|
std::vector<int64_t> Adder::get_stride() const {
|
||||||
auto value_ptr = GetAttr(kStride);
|
auto value_ptr = GetAttr(kStride);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { (void)this->AddAttr(kPadList, MakeValue(pad_list)); }
|
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||||
|
(void)this->AddAttr(kPadList, api::MakeValue(pad_list));
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Adder::get_pad_list() const {
|
std::vector<int64_t> Adder::get_pad_list() const {
|
||||||
auto value_ptr = GetAttr(kPadList);
|
auto value_ptr = GetAttr(kPadList);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_dilation(const std::vector<int64_t> &dilation) { (void)this->AddAttr(kDilation, MakeValue(dilation)); }
|
void Adder::set_dilation(const std::vector<int64_t> &dilation) {
|
||||||
|
(void)this->AddAttr(kDilation, api::MakeValue(dilation));
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Adder::get_dilation() const {
|
std::vector<int64_t> Adder::get_dilation() const {
|
||||||
auto value_ptr = GetAttr(kDilation);
|
auto value_ptr = GetAttr(kDilation);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adder::set_group(const int64_t group) { (void)this->AddAttr(kGroup, MakeValue(group)); }
|
void Adder::set_group(const int64_t group) { (void)this->AddAttr(kGroup, api::MakeValue(group)); }
|
||||||
|
|
||||||
int64_t Adder::get_group() const {
|
int64_t Adder::get_group() const {
|
||||||
auto value_ptr = GetAttr(kGroup);
|
auto value_ptr = GetAttr(kGroup);
|
||||||
|
@ -96,7 +104,7 @@ int64_t Adder::get_group() const {
|
||||||
|
|
||||||
void Adder::set_format(const Format &format) {
|
void Adder::set_format(const Format &format) {
|
||||||
int64_t swi = format;
|
int64_t swi = format;
|
||||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
(void)this->AddAttr(kFormat, api::MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
Format Adder::get_format() const {
|
Format Adder::get_format() const {
|
||||||
|
|
|
@ -21,22 +21,19 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "mindapi/base/format.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAdder = "Adder";
|
constexpr auto kNameAdder = "Adder";
|
||||||
/// \brief All defined All operator prototype of lite.
|
/// \brief All defined All operator prototype of lite.
|
||||||
class MS_CORE_API Adder : public PrimitiveC {
|
class MIND_API Adder : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Adder);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {}
|
explicit Adder(const std::string &k_name = kNameAdder) : BaseOperator(k_name) {}
|
||||||
|
|
||||||
/// \brief Destructor.
|
|
||||||
~Adder() = default;
|
|
||||||
MS_DECLARE_PARENT(Adder, PrimitiveC);
|
|
||||||
|
|
||||||
/// \brief Method to init the op's attributes.
|
/// \brief Method to init the op's attributes.
|
||||||
///
|
///
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/addn.h"
|
#include "ops/addn.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -83,6 +85,8 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
|
||||||
return elements[0]->BuildType();
|
return elements[0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(AddN, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -18,27 +18,24 @@
|
||||||
#define MINDSPORE_CORE_OPS_ADDN_H_
|
#define MINDSPORE_CORE_OPS_ADDN_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAddN = "AddN";
|
constexpr auto kNameAddN = "AddN";
|
||||||
/// \brief Computes addition of all input tensors element-wise.
|
/// \brief Computes addition of all input tensors element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.AddN for more details.
|
/// Refer to Python API @ref mindspore.ops.AddN for more details.
|
||||||
class MS_CORE_API AddN : public PrimitiveC {
|
class MIND_API AddN : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(AddN);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
AddN() : PrimitiveC(kNameAddN) { InitIOName({"inputs"}, {"sum"}); }
|
AddN() : BaseOperator(kNameAddN) { InitIOName({"inputs"}, {"sum"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~AddN() = default;
|
|
||||||
MS_DECLARE_PARENT(AddN, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AddN for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AddN for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,11 @@
|
||||||
#include "ops/affine.h"
|
#include "ops/affine.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(Affine, PrimitiveC, BaseOperator);
|
||||||
void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a, bool transpose_b) {
|
void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a, bool transpose_b) {
|
||||||
this->set_context(contexts);
|
this->set_context(contexts);
|
||||||
this->set_output_dim(output_dim);
|
this->set_output_dim(output_dim);
|
||||||
|
@ -27,17 +30,17 @@ void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool
|
||||||
}
|
}
|
||||||
|
|
||||||
void Affine::set_context(const std::vector<int64_t> &context) {
|
void Affine::set_context(const std::vector<int64_t> &context) {
|
||||||
(void)this->AddAttr(kAffineContext, MakeValue(context));
|
(void)this->AddAttr(kAffineContext, api::MakeValue(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, MakeValue(output_dim)); }
|
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, api::MakeValue(output_dim)); }
|
||||||
|
|
||||||
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, api::MakeValue(transpose_a)); }
|
||||||
|
|
||||||
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, api::MakeValue(transpose_b)); }
|
||||||
|
|
||||||
void Affine::set_activation_type(const ActivationType &activation_type) {
|
void Affine::set_activation_type(const ActivationType &activation_type) {
|
||||||
(void)this->AddAttr(kActivationType, MakeValue(static_cast<int64_t>(activation_type)));
|
(void)this->AddAttr(kActivationType, api::MakeValue(static_cast<int64_t>(activation_type)));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Affine::get_transpose_a() const {
|
bool Affine::get_transpose_a() const {
|
||||||
|
|
|
@ -18,25 +18,21 @@
|
||||||
#define MINDSPORE_CORE_OPS_AFFINE_H_
|
#define MINDSPORE_CORE_OPS_AFFINE_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/op_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
constexpr auto kNameAffine = "Affine";
|
constexpr auto kNameAffine = "Affine";
|
||||||
constexpr auto kAffineContext = "context";
|
constexpr auto kAffineContext = "context";
|
||||||
constexpr auto kAffineOutputDim = "output_dim";
|
constexpr auto kAffineOutputDim = "output_dim";
|
||||||
|
|
||||||
/// \brief Assert defined Affine operator prototype of lite.
|
/// \brief Assert defined Affine operator prototype of lite.
|
||||||
class MS_CORE_API Affine : public PrimitiveC {
|
class MIND_API Affine : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Affine);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Affine() : PrimitiveC(kNameAffine) { InitIOName({"x1", "x2"}, {"outputs"}); }
|
Affine() : BaseOperator(kNameAffine) { InitIOName({"x1", "x2"}, {"outputs"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Affine() = default;
|
|
||||||
MS_DECLARE_PARENT(Affine, PrimitiveC);
|
|
||||||
/// \brief Method to init the op's attributes.
|
/// \brief Method to init the op's attributes.
|
||||||
void Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a = false,
|
void Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool transpose_a = false,
|
||||||
bool transpose_b = false);
|
bool transpose_b = false);
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
#include "ops/all.h"
|
#include "ops/all.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(All, PrimitiveC, BaseOperator);
|
||||||
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
|
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
|
||||||
|
|
||||||
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
|
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
|
||||||
|
|
||||||
int64_t All::get_keep_dims() const {
|
int64_t All::get_keep_dims() const {
|
||||||
auto value_ptr = GetAttr(kKeepDims);
|
auto value_ptr = GetAttr(kKeepDims);
|
||||||
|
|
|
@ -16,23 +16,18 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_ALL_H_
|
#ifndef MINDSPORE_CORE_OPS_ALL_H_
|
||||||
#define MINDSPORE_CORE_OPS_ALL_H_
|
#define MINDSPORE_CORE_OPS_ALL_H_
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAll = "All";
|
constexpr auto kNameAll = "All";
|
||||||
/// \brief All defined All operator prototype of lite.
|
/// \brief All defined All operator prototype of lite.
|
||||||
class MS_CORE_API All : public PrimitiveC {
|
class MIND_API All : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(All);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
All() : PrimitiveC(kNameAll) {}
|
All() : BaseOperator(kNameAll) {}
|
||||||
|
|
||||||
/// \brief Destructor.
|
|
||||||
~All() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(All, PrimitiveC);
|
|
||||||
|
|
||||||
/// \brief Method to init the op's attributes.
|
/// \brief Method to init the op's attributes.
|
||||||
///
|
///
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
#include "ops/all_gather.h"
|
#include "ops/all_gather.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(AllGather, PrimitiveC, BaseOperator);
|
||||||
void AllGather::set_group(const string &group) {
|
void AllGather::set_group(const string &group) {
|
||||||
std::string g = group;
|
std::string g = group;
|
||||||
(void)this->AddAttr(kGroup, MakeValue(g));
|
(void)this->AddAttr(kGroup, api::MakeValue(g));
|
||||||
}
|
}
|
||||||
std::string AllGather::get_group() const {
|
std::string AllGather::get_group() const {
|
||||||
auto value_ptr = GetAttr(kGroup);
|
auto value_ptr = GetAttr(kGroup);
|
||||||
|
@ -30,7 +32,7 @@ std::string AllGather::get_group() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllGather::set_rank_size(int rank_size) {
|
void AllGather::set_rank_size(int rank_size) {
|
||||||
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
|
(void)this->AddAttr(kRankSize, api::MakeValue(static_cast<int64_t>(rank_size)));
|
||||||
}
|
}
|
||||||
int AllGather::get_rank_size() const {
|
int AllGather::get_rank_size() const {
|
||||||
auto value_ptr = GetAttr(kRankSize);
|
auto value_ptr = GetAttr(kRankSize);
|
||||||
|
|
|
@ -20,18 +20,16 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAllGather = "AllGather";
|
constexpr auto kNameAllGather = "AllGather";
|
||||||
class MS_CORE_API AllGather : public PrimitiveC {
|
class MIND_API AllGather : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
AllGather() : PrimitiveC(kNameAllGather) { InitIOName({"input_x"}, {"output"}); }
|
MIND_API_BASE_MEMBER(AllGather);
|
||||||
~AllGather() = default;
|
AllGather() : BaseOperator(kNameAllGather) { InitIOName({"input_x"}, {"output"}); }
|
||||||
MS_DECLARE_PARENT(AllGather, PrimitiveC);
|
|
||||||
void Init() {}
|
void Init() {}
|
||||||
void set_group(const std::string &format);
|
void set_group(const std::string &format);
|
||||||
std::string get_group() const;
|
std::string get_group() const;
|
||||||
|
|
|
@ -1,164 +1,167 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
*
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/apply_ada_max.h"
|
#include "ops/apply_ada_max.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
namespace mindspore {
|
#include "mindapi/src/helper.h"
|
||||||
namespace ops {
|
|
||||||
namespace {
|
namespace mindspore {
|
||||||
abstract::TupleShapePtr ApplyAdaMaxInferShape(const PrimitivePtr &primitive,
|
namespace ops {
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
namespace {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
abstract::TupleShapePtr ApplyAdaMaxInferShape(const PrimitivePtr &primitive,
|
||||||
const int64_t kInputNum = 9;
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
primitive->name());
|
const int64_t kInputNum = 9;
|
||||||
for (const auto &item : input_args) {
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
primitive->name());
|
||||||
}
|
for (const auto &item : input_args) {
|
||||||
auto prim_name = primitive->name();
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
auto var_shape = input_args[kInputIndex0]->BuildShape();
|
}
|
||||||
auto m_shape = input_args[kInputIndex1]->BuildShape();
|
auto prim_name = primitive->name();
|
||||||
auto v_shape = input_args[kInputIndex2]->BuildShape();
|
auto var_shape = input_args[kInputIndex0]->BuildShape();
|
||||||
auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>();
|
auto m_shape = input_args[kInputIndex1]->BuildShape();
|
||||||
auto m_shape_ptr = m_shape->cast<abstract::ShapePtr>();
|
auto v_shape = input_args[kInputIndex2]->BuildShape();
|
||||||
auto v_shape_ptr = v_shape->cast<abstract::ShapePtr>();
|
auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>();
|
||||||
auto beta1_power_shape =
|
auto m_shape_ptr = m_shape->cast<abstract::ShapePtr>();
|
||||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
auto v_shape_ptr = v_shape->cast<abstract::ShapePtr>();
|
||||||
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
auto beta1_power_shape =
|
||||||
auto beta1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
|
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||||
auto beta2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
|
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||||
auto epsilon_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
|
auto beta1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
|
||||||
auto grad_shape = input_args[kInputIndex8]->BuildShape();
|
auto beta2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
|
||||||
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
|
auto epsilon_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
|
||||||
// beta1_power,lr,beta1,beta2,epsilon must be scalar
|
auto grad_shape = input_args[kInputIndex8]->BuildShape();
|
||||||
const int64_t kInputShape = 1;
|
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
|
||||||
(void)CheckAndConvertUtils::CheckInteger("beta1 power's rank", beta1_power_shape.size(), kLessEqual, kInputShape,
|
// beta1_power,lr,beta1,beta2,epsilon must be scalar
|
||||||
prim_name);
|
const int64_t kInputShape = 1;
|
||||||
if (beta1_power_shape.size() == 1) {
|
(void)CheckAndConvertUtils::CheckInteger("beta1 power's rank", beta1_power_shape.size(), kLessEqual, kInputShape,
|
||||||
(void)CheckAndConvertUtils::CheckInteger("beta1_power_shape[0]", beta1_power_shape.size(), kEqual, kInputShape,
|
prim_name);
|
||||||
prim_name);
|
if (beta1_power_shape.size() == 1) {
|
||||||
}
|
(void)CheckAndConvertUtils::CheckInteger("beta1_power_shape[0]", beta1_power_shape.size(), kEqual, kInputShape,
|
||||||
(void)CheckAndConvertUtils::CheckInteger("lr's rank", lr_shape.size(), kLessEqual, kInputShape, prim_name);
|
prim_name);
|
||||||
if (lr_shape.size() == 1) {
|
}
|
||||||
(void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape.size(), kEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("lr's rank", lr_shape.size(), kLessEqual, kInputShape, prim_name);
|
||||||
}
|
if (lr_shape.size() == 1) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("beta1's rank", beta1_shape.size(), kLessEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape.size(), kEqual, kInputShape, prim_name);
|
||||||
if (beta1_shape.size() == 1) {
|
}
|
||||||
(void)CheckAndConvertUtils::CheckInteger("beta1_shape[0]", beta1_shape.size(), kEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("beta1's rank", beta1_shape.size(), kLessEqual, kInputShape, prim_name);
|
||||||
}
|
if (beta1_shape.size() == 1) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("beta2's rank", beta2_shape.size(), kLessEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("beta1_shape[0]", beta1_shape.size(), kEqual, kInputShape, prim_name);
|
||||||
if (beta2_shape.size() == 1) {
|
}
|
||||||
(void)CheckAndConvertUtils::CheckInteger("beta2_shape[0]", beta2_shape.size(), kEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("beta2's rank", beta2_shape.size(), kLessEqual, kInputShape, prim_name);
|
||||||
}
|
if (beta2_shape.size() == 1) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("epsilon's rank", epsilon_shape.size(), kLessEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("beta2_shape[0]", beta2_shape.size(), kEqual, kInputShape, prim_name);
|
||||||
if (epsilon_shape.size() == 1) {
|
}
|
||||||
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape[0]", epsilon_shape.size(), kEqual, kInputShape, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("epsilon's rank", epsilon_shape.size(), kLessEqual, kInputShape, prim_name);
|
||||||
}
|
if (epsilon_shape.size() == 1) {
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape[0]", epsilon_shape.size(), kEqual, kInputShape, prim_name);
|
||||||
// var, m,v and grad must have the same shape
|
}
|
||||||
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
|
|
||||||
same_shape_args_map.insert({"m", m_shape});
|
// var, m,v and grad must have the same shape
|
||||||
same_shape_args_map.insert({"v", v_shape});
|
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
|
||||||
same_shape_args_map.insert({"grad", grad_shape});
|
same_shape_args_map.insert({"m", m_shape});
|
||||||
if (!var_shape_ptr->IsDynamic() && !m_shape_ptr->IsDynamic()) {
|
same_shape_args_map.insert({"v", v_shape});
|
||||||
if (*m_shape != *var_shape) {
|
same_shape_args_map.insert({"grad", grad_shape});
|
||||||
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg m shape " << m_shape->ToString()
|
if (!var_shape_ptr->IsDynamic() && !m_shape_ptr->IsDynamic()) {
|
||||||
<< " are not consistent with var shape " << var_shape->ToString();
|
if (*m_shape != *var_shape) {
|
||||||
}
|
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg m shape " << m_shape->ToString()
|
||||||
}
|
<< " are not consistent with var shape " << var_shape->ToString();
|
||||||
if (!v_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
|
}
|
||||||
if (*v_shape != *var_shape) {
|
}
|
||||||
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg v shape " << v_shape->ToString()
|
if (!v_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
|
||||||
<< " are not consistent with var shape " << var_shape->ToString();
|
if (*v_shape != *var_shape) {
|
||||||
}
|
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg v shape " << v_shape->ToString()
|
||||||
}
|
<< " are not consistent with var shape " << var_shape->ToString();
|
||||||
if (!grad_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
|
}
|
||||||
if (*grad_shape != *var_shape) {
|
}
|
||||||
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg grad shape " << grad_shape->ToString()
|
if (!grad_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
|
||||||
<< " are not consistent with var shape " << var_shape->ToString();
|
if (*grad_shape != *var_shape) {
|
||||||
}
|
MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg grad shape " << grad_shape->ToString()
|
||||||
}
|
<< " are not consistent with var shape " << var_shape->ToString();
|
||||||
|
}
|
||||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape});
|
}
|
||||||
}
|
|
||||||
|
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape});
|
||||||
TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
}
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
|
||||||
auto prim_name = prim->name();
|
TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
const int64_t kInputNum = 9;
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
|
auto prim_name = prim->name();
|
||||||
prim_name);
|
const int64_t kInputNum = 9;
|
||||||
for (const auto &item : input_args) {
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
prim_name);
|
||||||
}
|
for (const auto &item : input_args) {
|
||||||
auto var_type = input_args[kInputIndex0]->BuildType();
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
auto m_type = input_args[kInputIndex1]->BuildType();
|
}
|
||||||
auto v_type = input_args[kInputIndex2]->BuildType();
|
auto var_type = input_args[kInputIndex0]->BuildType();
|
||||||
auto beta1_power_type = input_args[kInputIndex3]->BuildType();
|
auto m_type = input_args[kInputIndex1]->BuildType();
|
||||||
auto lr_type = input_args[kInputIndex4]->BuildType();
|
auto v_type = input_args[kInputIndex2]->BuildType();
|
||||||
auto beta1_type = input_args[kInputIndex5]->BuildType();
|
auto beta1_power_type = input_args[kInputIndex3]->BuildType();
|
||||||
auto beta2_type = input_args[kInputIndex6]->BuildType();
|
auto lr_type = input_args[kInputIndex4]->BuildType();
|
||||||
auto epsilon_type = input_args[kInputIndex7]->BuildType();
|
auto beta1_type = input_args[kInputIndex5]->BuildType();
|
||||||
auto grad_type = input_args[kInputIndex8]->BuildType();
|
auto beta2_type = input_args[kInputIndex6]->BuildType();
|
||||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
auto epsilon_type = input_args[kInputIndex7]->BuildType();
|
||||||
// m v grad must have the same type as var
|
auto grad_type = input_args[kInputIndex8]->BuildType();
|
||||||
std::map<std::string, TypePtr> args;
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||||
(void)args.insert({"var_type", var_type});
|
// m v grad must have the same type as var
|
||||||
(void)args.insert({"m_type", m_type});
|
std::map<std::string, TypePtr> args;
|
||||||
(void)args.insert({"v_type", v_type});
|
(void)args.insert({"var_type", var_type});
|
||||||
(void)args.insert({"grad_type", grad_type});
|
(void)args.insert({"m_type", m_type});
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
(void)args.insert({"v_type", v_type});
|
||||||
|
(void)args.insert({"grad_type", grad_type});
|
||||||
std::map<std::string, TypePtr> args_beta1_power;
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||||
std::map<std::string, TypePtr> args_lr;
|
|
||||||
std::map<std::string, TypePtr> args_beta1;
|
std::map<std::string, TypePtr> args_beta1_power;
|
||||||
std::map<std::string, TypePtr> args_beta2;
|
std::map<std::string, TypePtr> args_lr;
|
||||||
std::map<std::string, TypePtr> args_epsilon;
|
std::map<std::string, TypePtr> args_beta1;
|
||||||
|
std::map<std::string, TypePtr> args_beta2;
|
||||||
(void)args_beta1_power.insert({"beta1_power_type", beta1_power_type});
|
std::map<std::string, TypePtr> args_epsilon;
|
||||||
(void)args_lr.insert({"lr_type", lr_type});
|
|
||||||
(void)args_beta1.insert({"beta1_type", beta1_type});
|
(void)args_beta1_power.insert({"beta1_power_type", beta1_power_type});
|
||||||
(void)args_beta2.insert({"beta2_type", beta2_type});
|
(void)args_lr.insert({"lr_type", lr_type});
|
||||||
(void)args_epsilon.insert({"epsilon_type", epsilon_type});
|
(void)args_beta1.insert({"beta1_type", beta1_type});
|
||||||
|
(void)args_beta2.insert({"beta2_type", beta2_type});
|
||||||
// beta1_power,lr,beta1,beta2,epsilon must be a scalar or zero dimension tensor type
|
(void)args_epsilon.insert({"epsilon_type", epsilon_type});
|
||||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1_power, valid_types, prim_name);
|
|
||||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
|
// beta1_power,lr,beta1,beta2,epsilon must be a scalar or zero dimension tensor type
|
||||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1, valid_types, prim_name);
|
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1_power, valid_types, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta2, valid_types, prim_name);
|
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_epsilon, valid_types, prim_name);
|
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1, valid_types, prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta2, valid_types, prim_name);
|
||||||
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
|
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_epsilon, valid_types, prim_name);
|
||||||
}
|
|
||||||
} // namespace
|
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
|
||||||
|
}
|
||||||
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
} // namespace
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MIND_API_BASE_IMPL(ApplyAdaMax, PrimitiveC, BaseOperator);
|
||||||
auto infer_type = ApplyAdaMaxInferType(primitive, input_args);
|
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
auto infer_shape = ApplyAdaMaxInferShape(primitive, input_args);
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
}
|
auto infer_type = ApplyAdaMaxInferType(primitive, input_args);
|
||||||
|
auto infer_shape = ApplyAdaMaxInferShape(primitive, input_args);
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdaMax, prim::kPrimApplyAdaMax, ApplyAdaMaxInfer, nullptr, true);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
} // namespace ops
|
}
|
||||||
} // namespace mindspore
|
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdaMax, prim::kPrimApplyAdaMax, ApplyAdaMaxInfer, nullptr, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,45 +1,43 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
*
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
|
#ifndef MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
|
||||||
#define MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
|
#define MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
namespace mindspore {
|
||||||
namespace mindspore {
|
namespace ops {
|
||||||
namespace ops {
|
constexpr auto kNameApplyAdaMax = "ApplyAdaMax";
|
||||||
constexpr auto kNameApplyAdaMax = "ApplyAdaMax";
|
class MIND_API ApplyAdaMax : public BaseOperator {
|
||||||
class ApplyAdaMax : public PrimitiveC {
|
public:
|
||||||
public:
|
MIND_API_BASE_MEMBER(ApplyAdaMax);
|
||||||
ApplyAdaMax() : PrimitiveC(kNameApplyAdaMax) {
|
ApplyAdaMax() : BaseOperator(kNameApplyAdaMax) {
|
||||||
InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
|
InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
|
||||||
}
|
}
|
||||||
~ApplyAdaMax() = default;
|
};
|
||||||
MS_DECLARE_PARENT(ApplyAdaMax, PrimitiveC);
|
abstract::AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
};
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>;
|
||||||
|
} // namespace ops
|
||||||
using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>;
|
} // namespace mindspore
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
#endif // MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -121,6 +122,8 @@ TuplePtr ApplyAdadeltaInferType(const PrimitivePtr &primitive, const std::vector
|
||||||
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type, accum_update_type});
|
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type, accum_update_type});
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyAdadelta, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto infer_type = ApplyAdadeltaInferType(primitive, input_args);
|
auto infer_type = ApplyAdadeltaInferType(primitive, input_args);
|
||||||
|
|
|
@ -22,23 +22,21 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyAdadelta = "ApplyAdadelta";
|
constexpr auto kNameApplyAdadelta = "ApplyAdadelta";
|
||||||
class ApplyAdadelta : public PrimitiveC {
|
class MIND_API ApplyAdadelta : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyAdadelta() : PrimitiveC(kNameApplyAdadelta) {
|
MIND_API_BASE_MEMBER(ApplyAdadelta);
|
||||||
|
ApplyAdadelta() : BaseOperator(kNameApplyAdadelta) {
|
||||||
InitIOName({"var", "accum", "accum_update", "lr", "rho", "epsilon", "grad"}, {"var", "accum", "accum_update"});
|
InitIOName({"var", "accum", "accum_update", "lr", "rho", "epsilon", "grad"}, {"var", "accum", "accum_update"});
|
||||||
}
|
}
|
||||||
~ApplyAdadelta() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyAdadelta, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyAdadeltaPtr = std::shared_ptr<ApplyAdadelta>;
|
using kPrimApplyAdadeltaPtr = std::shared_ptr<ApplyAdadelta>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,6 +22,8 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -81,6 +83,7 @@ TuplePtr ApplyAdagradInferType(const PrimitivePtr &primitive, const std::vector<
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyAdagrad, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,22 +22,20 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyAdagrad = "ApplyAdagrad";
|
constexpr auto kNameApplyAdagrad = "ApplyAdagrad";
|
||||||
class ApplyAdagrad : public PrimitiveC {
|
class MIND_API ApplyAdagrad : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyAdagrad() : PrimitiveC(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
MIND_API_BASE_MEMBER(ApplyAdagrad);
|
||||||
~ApplyAdagrad() = default;
|
ApplyAdagrad() : BaseOperator(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||||
MS_DECLARE_PARENT(ApplyAdagrad, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using kPrimApplyAdagradPtr = std::shared_ptr<ApplyAdagrad>;
|
using kPrimApplyAdagradPtr = std::shared_ptr<ApplyAdagrad>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -23,10 +23,10 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
abstract::TupleShapePtr ApplyAdagradDAInferShape(const PrimitivePtr &primitive,
|
abstract::TupleShapePtr ApplyAdagradDAInferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
@ -98,6 +98,7 @@ TuplePtr ApplyAdagradDAInferType(const PrimitivePtr &prim, const std::vector<Abs
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyAdagradDA, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,31 +22,26 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyAdagradDA = "ApplyAdagradDA";
|
constexpr auto kNameApplyAdagradDA = "ApplyAdagradDA";
|
||||||
/// \brief Update var according to the proximal adagrad scheme.
|
/// \brief Update var according to the proximal adagrad scheme.
|
||||||
/// Refer to Python API @ref mindspore.ops.ApplyAdagradDA for more details.
|
/// Refer to Python API @ref mindspore.ops.ApplyAdagradDA for more details.
|
||||||
class ApplyAdagradDA : public PrimitiveC {
|
class MIND_API ApplyAdagradDA : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ApplyAdagradDA);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
ApplyAdagradDA() : PrimitiveC(kNameApplyAdagradDA) {
|
ApplyAdagradDA() : BaseOperator(kNameApplyAdagradDA) {
|
||||||
InitIOName({"var", "gradient_accumulator", "gradient_squared_accumulator", "grad", "lr", "l1", "l2", "global_step"},
|
InitIOName({"var", "gradient_accumulator", "gradient_squared_accumulator", "grad", "lr", "l1", "l2", "global_step"},
|
||||||
{"var", "gradient_accumulator", "gradient_squared_accumulator"});
|
{"var", "gradient_accumulator", "gradient_squared_accumulator"});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Destructor.
|
|
||||||
~ApplyAdagradDA() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(ApplyAdagradDA, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -77,6 +78,7 @@ TuplePtr ApplyAdagradV2InferType(const PrimitivePtr &prim, const std::vector<Abs
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyAdagradV2, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,23 +22,19 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyAdagradV2 = "ApplyAdagradV2";
|
constexpr auto kNameApplyAdagradV2 = "ApplyAdagradV2";
|
||||||
class ApplyAdagradV2 : public PrimitiveC {
|
class MIND_API ApplyAdagradV2 : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyAdagradV2() : PrimitiveC(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
MIND_API_BASE_MEMBER(ApplyAdagradV2);
|
||||||
|
ApplyAdagradV2() : BaseOperator(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||||
~ApplyAdagradV2() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(ApplyAdagradV2, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyAdagradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyAdagradV2Ptr = std::shared_ptr<ApplyAdagradV2>;
|
using kPrimApplyAdagradV2Ptr = std::shared_ptr<ApplyAdagradV2>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,8 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -91,6 +93,7 @@ TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vect
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyAdamWithAmsgrad, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,24 +19,22 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyAdamWithAmsgrad = "ApplyAdamWithAmsgrad";
|
constexpr auto kNameApplyAdamWithAmsgrad = "ApplyAdamWithAmsgrad";
|
||||||
class ApplyAdamWithAmsgrad : public PrimitiveC {
|
class MIND_API ApplyAdamWithAmsgrad : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyAdamWithAmsgrad() : PrimitiveC(kNameApplyAdamWithAmsgrad) {
|
MIND_API_BASE_MEMBER(ApplyAdamWithAmsgrad);
|
||||||
|
ApplyAdamWithAmsgrad() : BaseOperator(kNameApplyAdamWithAmsgrad) {
|
||||||
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
|
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
|
||||||
}
|
}
|
||||||
~ApplyAdamWithAmsgrad() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyAdamWithAmsgrad, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimApplyAdamWithAmsgradPtr = std::shared_ptr<ApplyAdamWithAmsgrad>;
|
using PrimApplyAdamWithAmsgradPtr = std::shared_ptr<ApplyAdamWithAmsgrad>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -116,6 +117,7 @@ TuplePtr ApplyAddSignInferType(const PrimitivePtr &prim, const std::vector<Abstr
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyAddSign, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -21,27 +21,23 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyAddSign = "ApplyAddSign";
|
constexpr auto kNameApplyAddSign = "ApplyAddSign";
|
||||||
|
|
||||||
class ApplyAddSign : public PrimitiveC {
|
class MIND_API ApplyAddSign : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyAddSign() : PrimitiveC(kNameApplyAddSign) {
|
MIND_API_BASE_MEMBER(ApplyAddSign);
|
||||||
|
ApplyAddSign() : BaseOperator(kNameApplyAddSign) {
|
||||||
InitIOName({"var", "m", "lr", "alpha", "sign_decay", "beta", "grad"}, {"var", "m"});
|
InitIOName({"var", "m", "lr", "alpha", "sign_decay", "beta", "grad"}, {"var", "m"});
|
||||||
}
|
}
|
||||||
|
|
||||||
~ApplyAddSign() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(ApplyAddSign, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyAddSignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyAddSignPtr = std::shared_ptr<ApplyAddSign>;
|
using kPrimApplyAddSignPtr = std::shared_ptr<ApplyAddSign>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "ops/apply_centered_rms_prop.h"
|
#include "ops/apply_centered_rms_prop.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -106,6 +107,8 @@ TypePtr ApplyCenteredRMSPropInferType(const PrimitivePtr &primitive, const std::
|
||||||
return var_dtype;
|
return var_dtype;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyCenteredRMSProp, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto infer_type = ApplyCenteredRMSPropInferType(primitive, input_args);
|
auto infer_type = ApplyCenteredRMSPropInferType(primitive, input_args);
|
||||||
|
|
|
@ -22,25 +22,23 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyCenteredRMSProp = "ApplyCenteredRMSProp";
|
constexpr auto kNameApplyCenteredRMSProp = "ApplyCenteredRMSProp";
|
||||||
class ApplyCenteredRMSProp : public PrimitiveC {
|
class MIND_API ApplyCenteredRMSProp : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyCenteredRMSProp() : PrimitiveC(kNameApplyCenteredRMSProp) {
|
MIND_API_BASE_MEMBER(ApplyCenteredRMSProp);
|
||||||
|
ApplyCenteredRMSProp() : BaseOperator(kNameApplyCenteredRMSProp) {
|
||||||
InitIOName(
|
InitIOName(
|
||||||
{"var", "mean_gradient", "mean_square", "moment", "grad", "learning_rate", "decay", "momentum", "epsilon"},
|
{"var", "mean_gradient", "mean_square", "moment", "grad", "learning_rate", "decay", "momentum", "epsilon"},
|
||||||
{"var", "mean_gradient", "mean_square", "moment"});
|
{"var", "mean_gradient", "mean_square", "moment"});
|
||||||
}
|
}
|
||||||
~ApplyCenteredRMSProp() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyCenteredRMSProp, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyCenteredRMSPropPtr = std::shared_ptr<ApplyCenteredRMSProp>;
|
using kPrimApplyCenteredRMSPropPtr = std::shared_ptr<ApplyCenteredRMSProp>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -91,6 +92,8 @@ TypePtr ApplyFtrlInferType(const PrimitivePtr &prim, const std::vector<AbstractB
|
||||||
return var_type;
|
return var_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyFtrl, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,24 +22,21 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyFtrl = "ApplyFtrl";
|
constexpr auto kNameApplyFtrl = "ApplyFtrl";
|
||||||
class ApplyFtrl : public PrimitiveC {
|
class MIND_API ApplyFtrl : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyFtrl() : PrimitiveC(kNameApplyFtrl) {
|
MIND_API_BASE_MEMBER(ApplyFtrl);
|
||||||
|
ApplyFtrl() : BaseOperator(kNameApplyFtrl) {
|
||||||
InitIOName({"var", "accum", "linear", "grad", "lr", "l1", "l2", "lr_power"}, {"var"});
|
InitIOName({"var", "accum", "linear", "grad", "lr", "l1", "l2", "lr_power"}, {"var"});
|
||||||
}
|
}
|
||||||
|
|
||||||
~ApplyFtrl() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyFtrl, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyFtrlPtr = std::shared_ptr<ApplyFtrl>;
|
using kPrimApplyFtrlPtr = std::shared_ptr<ApplyFtrl>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -72,6 +73,7 @@ TypePtr ApplyGradientDescentInferType(const PrimitivePtr &prim, const std::vecto
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyGradientDescent, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,24 +22,20 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyGradientDescent = "ApplyGradientDescent";
|
constexpr auto kNameApplyGradientDescent = "ApplyGradientDescent";
|
||||||
class ApplyGradientDescent : public PrimitiveC {
|
class MIND_API ApplyGradientDescent : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyGradientDescent() : PrimitiveC(kNameApplyGradientDescent) { InitIOName({"var", "alpha", "delta"}, {"var"}); }
|
MIND_API_BASE_MEMBER(ApplyGradientDescent);
|
||||||
|
ApplyGradientDescent() : BaseOperator(kNameApplyGradientDescent) { InitIOName({"var", "alpha", "delta"}, {"var"}); }
|
||||||
~ApplyGradientDescent() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(ApplyGradientDescent, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimApplyGradientDescentPtr = std::shared_ptr<ApplyGradientDescent>;
|
using PrimApplyGradientDescentPtr = std::shared_ptr<ApplyGradientDescent>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -22,6 +22,8 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -81,6 +83,7 @@ TuplePtr ApplyKerasMomentumInferType(const PrimitivePtr &prim, const std::vector
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyKerasMomentum, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,24 +22,22 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyKerasMomentum = "ApplyKerasMomentum";
|
constexpr auto kNameApplyKerasMomentum = "ApplyKerasMomentum";
|
||||||
class MS_CORE_API ApplyKerasMomentum : public PrimitiveC {
|
class MIND_API ApplyKerasMomentum : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyKerasMomentum() : PrimitiveC(kNameApplyKerasMomentum) {
|
MIND_API_BASE_MEMBER(ApplyKerasMomentum);
|
||||||
|
ApplyKerasMomentum() : BaseOperator(kNameApplyKerasMomentum) {
|
||||||
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
|
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
|
||||||
}
|
}
|
||||||
~ApplyKerasMomentum() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyKerasMomentum, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimApplyKerasMomentumPtr = std::shared_ptr<ApplyKerasMomentum>;
|
using PrimApplyKerasMomentumPtr = std::shared_ptr<ApplyKerasMomentum>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -30,15 +31,15 @@ void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const
|
||||||
}
|
}
|
||||||
|
|
||||||
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
|
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
|
||||||
(void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov));
|
(void)this->AddAttr(kUseNesterov, api::MakeValue(use_nesterov));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ApplyMomentum::set_use_locking(const bool use_locking) {
|
void ApplyMomentum::set_use_locking(const bool use_locking) {
|
||||||
(void)this->AddAttr(kUseLocking, MakeValue(use_locking));
|
(void)this->AddAttr(kUseLocking, api::MakeValue(use_locking));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
|
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
|
||||||
(void)this->AddAttr(kGradientScale, MakeValue(gradient_scale));
|
(void)this->AddAttr(kGradientScale, api::MakeValue(gradient_scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ApplyMomentum::get_use_nesterov() const {
|
bool ApplyMomentum::get_use_nesterov() const {
|
||||||
|
@ -102,6 +103,8 @@ TypePtr ApplyMomentumInferType(const PrimitivePtr &primitive, const std::vector<
|
||||||
return v_tensor_type;
|
return v_tensor_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyMomentum, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto infer_type = ApplyMomentumInferType(primitive, input_args);
|
auto infer_type = ApplyMomentumInferType(primitive, input_args);
|
||||||
|
|
|
@ -22,24 +22,21 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyMomentum = "ApplyMomentum";
|
constexpr auto kNameApplyMomentum = "ApplyMomentum";
|
||||||
/// \brief Optimizer that implements the Momentum algorithm.
|
/// \brief Optimizer that implements the Momentum algorithm.
|
||||||
/// Refer to Python API @ref mindspore.ops.ApplyMomentum for more details.
|
/// Refer to Python API @ref mindspore.ops.ApplyMomentum for more details.
|
||||||
class MS_CORE_API ApplyMomentum : public PrimitiveC {
|
class MIND_API ApplyMomentum : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ApplyMomentum);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
ApplyMomentum() : PrimitiveC(kNameApplyMomentum) {
|
ApplyMomentum() : BaseOperator(kNameApplyMomentum) {
|
||||||
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
|
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
|
||||||
}
|
}
|
||||||
/// \brief Destructor.
|
|
||||||
~ApplyMomentum() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ApplyMomentum for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ApplyMomentum for the inputs.
|
||||||
void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0);
|
void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0);
|
||||||
/// \brief Set use_nesterov.
|
/// \brief Set use_nesterov.
|
||||||
|
@ -61,8 +58,8 @@ class MS_CORE_API ApplyMomentum : public PrimitiveC {
|
||||||
/// \return gradient_scale.
|
/// \return gradient_scale.
|
||||||
float get_gradient_scale() const;
|
float get_gradient_scale() const;
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>;
|
using kPrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -108,6 +109,7 @@ TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<Ab
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyPowerSign, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,23 +19,21 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
|
constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
|
||||||
class ApplyPowerSign : public PrimitiveC {
|
class MIND_API ApplyPowerSign : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyPowerSign() : PrimitiveC(kNameApplyPowerSign) {
|
MIND_API_BASE_MEMBER(ApplyPowerSign);
|
||||||
|
ApplyPowerSign() : BaseOperator(kNameApplyPowerSign) {
|
||||||
InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
|
InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
|
||||||
}
|
}
|
||||||
~ApplyPowerSign() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyPowerSign, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
|
using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,6 +22,8 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -99,6 +101,7 @@ TuplePtr ApplyProximalAdagradInferType(const PrimitivePtr &primitive, const std:
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyProximalAdagrad, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,24 +22,22 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyProximalAdagrad = "ApplyProximalAdagrad";
|
constexpr auto kNameApplyProximalAdagrad = "ApplyProximalAdagrad";
|
||||||
class ApplyProximalAdagrad : public PrimitiveC {
|
class MIND_API ApplyProximalAdagrad : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyProximalAdagrad() : PrimitiveC(kNameApplyProximalAdagrad) {
|
MIND_API_BASE_MEMBER(ApplyProximalAdagrad);
|
||||||
|
ApplyProximalAdagrad() : BaseOperator(kNameApplyProximalAdagrad) {
|
||||||
InitIOName({"var", "accum", "lr", "l1", "l2", "grad"}, {"var", "accum"});
|
InitIOName({"var", "accum", "lr", "l1", "l2", "grad"}, {"var", "accum"});
|
||||||
}
|
}
|
||||||
~ApplyProximalAdagrad() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyProximalAdagrad, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyProximalAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using kPrimApplyProximalAdagradPtr = std::shared_ptr<ApplyProximalAdagrad>;
|
using kPrimApplyProximalAdagradPtr = std::shared_ptr<ApplyProximalAdagrad>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -98,6 +99,7 @@ TypePtr ApplyProximalGradientDescentInferType(const PrimitivePtr &prim,
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApplyProximalGradientDescent, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
const int64_t input_num = 5;
|
const int64_t input_num = 5;
|
||||||
|
|
|
@ -19,23 +19,22 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameApplyProximalGradientDescent = "ApplyProximalGradientDescent";
|
constexpr auto kNameApplyProximalGradientDescent = "ApplyProximalGradientDescent";
|
||||||
class ApplyProximalGradientDescent : public PrimitiveC {
|
class MIND_API ApplyProximalGradientDescent : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApplyProximalGradientDescent() : PrimitiveC(kNameApplyProximalGradientDescent) {
|
MIND_API_BASE_MEMBER(ApplyProximalGradientDescent);
|
||||||
|
ApplyProximalGradientDescent() : BaseOperator(kNameApplyProximalGradientDescent) {
|
||||||
InitIOName({"var", "alpha", "l1", "l2", "delta"}, {"var"});
|
InitIOName({"var", "alpha", "l1", "l2", "delta"}, {"var"});
|
||||||
}
|
}
|
||||||
~ApplyProximalGradientDescent() = default;
|
|
||||||
MS_DECLARE_PARENT(ApplyProximalGradientDescent, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const PrimitivePtr &primitive,
|
||||||
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,9 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -59,6 +62,8 @@ TypePtr ApproximateEqualInferType(const PrimitivePtr &prim, const std::vector<Ab
|
||||||
return y_dtype;
|
return y_dtype;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(ApproximateEqual, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,21 +19,18 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class ApproximateEqual : public PrimitiveC {
|
class MIND_API ApproximateEqual : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
ApproximateEqual() : PrimitiveC(prim::kPrimApproximateEqual->name()) {}
|
MIND_API_BASE_MEMBER(ApproximateEqual);
|
||||||
~ApproximateEqual() = default;
|
ApproximateEqual() : BaseOperator("ApproximateEqual") {}
|
||||||
MS_DECLARE_PARENT(ApproximateEqual, PrimitiveC);
|
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ApproximateEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimApproximateEqualPtr = std::shared_ptr<ApproximateEqual>;
|
using kPrimApproximateEqualPtr = std::shared_ptr<ApproximateEqual>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,6 +15,10 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/arg_max.h"
|
#include "ops/arg_max.h"
|
||||||
|
#include "mindapi/ir/type.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -23,15 +27,17 @@ void ArgMax::Init(const int64_t axis, const TypeId output_type) {
|
||||||
set_output_type(output_type);
|
set_output_type(output_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgMax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
void ArgMax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||||
void ArgMax::set_output_type(const TypeId output_type) { (void)this->AddAttr(kOutputType, TypeIdToType(output_type)); }
|
void ArgMax::set_output_type(const TypeId output_type) {
|
||||||
|
(void)this->AddAttr(kOutputType, api::Type::GetType(output_type));
|
||||||
|
}
|
||||||
|
|
||||||
int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||||
TypeId ArgMax::get_output_type() const {
|
TypeId ArgMax::get_output_type() const {
|
||||||
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
|
auto type_ptr = GetAttr(kOutputType)->cast<api::TensorTypePtr>()->element();
|
||||||
return type_ptr->type_id();
|
return type_ptr->type_id();
|
||||||
}
|
}
|
||||||
|
MIND_API_BASE_IMPL(ArgMax, PrimitiveC, BaseOperator);
|
||||||
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
|
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -20,24 +20,21 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/op_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/type_id.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameArgMax = "Argmax";
|
constexpr auto kNameArgMax = "Argmax";
|
||||||
/// \brief Returns the indices of the maximum value of a tensor across the axis.
|
/// \brief Returns the indices of the maximum value of a tensor across the axis.
|
||||||
/// Refer to Python API @ref mindspore.ops.Argmax for more details.
|
/// Refer to Python API @ref mindspore.ops.Argmax for more details.
|
||||||
class MS_CORE_API ArgMax : public PrimitiveC {
|
class MIND_API ArgMax : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ArgMax);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); }
|
ArgMax() : BaseOperator(kNameArgMax) { InitIOName({"x"}, {"output"}); }
|
||||||
explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
explicit ArgMax(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~ArgMax() = default;
|
|
||||||
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmax for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmax for the inputs.
|
||||||
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
||||||
/// \brief Set axis.
|
/// \brief Set axis.
|
||||||
|
@ -54,8 +51,8 @@ class MS_CORE_API ArgMax : public PrimitiveC {
|
||||||
/// \return output_type.
|
/// \return output_type.
|
||||||
TypeId get_output_type() const;
|
TypeId get_output_type() const;
|
||||||
};
|
};
|
||||||
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -16,21 +16,28 @@
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include "ops/arg_min.h"
|
#include "ops/arg_min.h"
|
||||||
|
#include "mindapi/ir/type.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(ArgMin, PrimitiveC, BaseOperator);
|
||||||
void ArgMin::Init(const int64_t axis, const TypeId output_type) {
|
void ArgMin::Init(const int64_t axis, const TypeId output_type) {
|
||||||
set_axis(axis);
|
set_axis(axis);
|
||||||
set_output_type(output_type);
|
set_output_type(output_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgMin::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
void ArgMin::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||||
void ArgMin::set_output_type(const TypeId output_type) { (void)this->AddAttr(kOutputType, TypeIdToType(output_type)); }
|
void ArgMin::set_output_type(const TypeId output_type) {
|
||||||
|
(void)this->AddAttr(kOutputType, api::Type::GetType(output_type));
|
||||||
|
}
|
||||||
|
|
||||||
int64_t ArgMin::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
int64_t ArgMin::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||||
|
|
||||||
TypeId ArgMin::get_output_type() const {
|
TypeId ArgMin::get_output_type() const {
|
||||||
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
|
auto type_ptr = GetAttr(kOutputType)->cast<api::TensorTypePtr>()->element();
|
||||||
return type_ptr->type_id();
|
return type_ptr->type_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,24 +20,21 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "ops/op_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/type_id.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameArgMin = "ArgMin";
|
constexpr auto kNameArgMin = "ArgMin";
|
||||||
/// \brief Returns the indices of the minimum value of a tensor across the axis.
|
/// \brief Returns the indices of the minimum value of a tensor across the axis.
|
||||||
/// Refer to Python API @ref mindspore.ops.Argmin for more details.
|
/// Refer to Python API @ref mindspore.ops.Argmin for more details.
|
||||||
class MS_CORE_API ArgMin : public PrimitiveC {
|
class MIND_API ArgMin : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ArgMin);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); }
|
ArgMin() : BaseOperator(kNameArgMin) { InitIOName({"x"}, {"output"}); }
|
||||||
explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
explicit ArgMin(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~ArgMin() = default;
|
|
||||||
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmin for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Argmin for the inputs.
|
||||||
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
||||||
/// \brief Set axis.
|
/// \brief Set axis.
|
||||||
|
@ -54,8 +51,8 @@ class MS_CORE_API ArgMin : public PrimitiveC {
|
||||||
/// \return output_type.
|
/// \return output_type.
|
||||||
TypeId get_output_type() const;
|
TypeId get_output_type() const;
|
||||||
};
|
};
|
||||||
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using PrimArgMin = std::shared_ptr<ArgMin>;
|
using PrimArgMin = std::shared_ptr<ArgMin>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
#include "abstract/param_validator.h"
|
#include "abstract/param_validator.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -42,6 +43,7 @@ TypePtr AsinInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Asin, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,30 +22,26 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAsin = "Asin";
|
constexpr auto kNameAsin = "Asin";
|
||||||
/// \brief Computes arcsine of input tensors element-wise.
|
/// \brief Computes arcsine of input tensors element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.Asin for more details.
|
/// Refer to Python API @ref mindspore.ops.Asin for more details.
|
||||||
class MS_CORE_API Asin : public PrimitiveC {
|
class MIND_API Asin : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Asin);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Asin() : PrimitiveC(kNameAsin) { InitIOName({"x"}, {"y"}); }
|
Asin() : BaseOperator(kNameAsin) { InitIOName({"x"}, {"y"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Asin() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Asin, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Asin for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Asin for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimAsinPtr = std::shared_ptr<Asin>;
|
using PrimAsinPtr = std::shared_ptr<Asin>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -15,6 +15,11 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/asinh.h"
|
#include "ops/asinh.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "abstract/param_validator.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -42,6 +47,7 @@ TypePtr AsinhInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Asinh, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -22,29 +22,24 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAsinh = "Asinh";
|
constexpr auto kNameAsinh = "Asinh";
|
||||||
/// \brief Computes arcsinh of input tensors element-wise.
|
/// \brief Computes arcsinh of input tensors element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.Asinh for more details.
|
/// Refer to Python API @ref mindspore.ops.Asinh for more details.
|
||||||
class MS_CORE_API Asinh : public PrimitiveC {
|
class MIND_API Asinh : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor.
|
MIND_API_BASE_MEMBER(Asinh);
|
||||||
Asinh() : PrimitiveC(kNameAsinh) { InitIOName({"x"}, {"y"}); }
|
Asinh() : BaseOperator(kNameAsinh) { InitIOName({"x"}, {"y"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Asinh() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Asinh, PrimitiveC);
|
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimAsinhPtr = std::shared_ptr<Asinh>;
|
using PrimAsinhPtr = std::shared_ptr<Asinh>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -21,13 +21,16 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/assert.h"
|
#include "ops/assert.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(Assert, PrimitiveC, BaseOperator);
|
||||||
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
|
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
|
||||||
|
|
||||||
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, MakeValue(summarize)); }
|
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, api::MakeValue(summarize)); }
|
||||||
|
|
||||||
int64_t Assert::get_summarize() const {
|
int64_t Assert::get_summarize() const {
|
||||||
auto value_ptr = GetAttr(kSummarize);
|
auto value_ptr = GetAttr(kSummarize);
|
||||||
|
|
|
@ -18,23 +18,18 @@
|
||||||
#define MINDSPORE_CORE_OPS_ASSERT_H_
|
#define MINDSPORE_CORE_OPS_ASSERT_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAssert = "Assert";
|
constexpr auto kNameAssert = "Assert";
|
||||||
/// \brief Assert defined Assert operator prototype of lite.
|
/// \brief Assert defined Assert operator prototype of lite.
|
||||||
class MS_CORE_API Assert : public PrimitiveC {
|
class MIND_API Assert : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Assert);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Assert() : PrimitiveC(kNameAssert) {}
|
Assert() : BaseOperator(kNameAssert) {}
|
||||||
|
|
||||||
/// \brief Destructor.
|
|
||||||
~Assert() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Assert, PrimitiveC);
|
|
||||||
|
|
||||||
/// \brief Method to init the op's attributes.
|
/// \brief Method to init the op's attributes.
|
||||||
///
|
///
|
||||||
|
@ -52,8 +47,8 @@ class MS_CORE_API Assert : public PrimitiveC {
|
||||||
int64_t get_summarize() const;
|
int64_t get_summarize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -23,9 +23,12 @@
|
||||||
#include "ops/assign.h"
|
#include "ops/assign.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "ir/dtype/ref.h"
|
#include "ir/dtype/ref.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(Assign, PrimitiveC, BaseOperator);
|
||||||
abstract::ShapePtr AssignInferShape(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr AssignInferShape(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
auto prim_name = prim->name();
|
auto prim_name = prim->name();
|
||||||
|
|
|
@ -19,21 +19,18 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAssign = "Assign";
|
constexpr auto kNameAssign = "Assign";
|
||||||
/// \brief Assigns Parameter with a value. Refer to Python API @ref mindspore.ops.Assign for more details.
|
/// \brief Assigns Parameter with a value. Refer to Python API @ref mindspore.ops.Assign for more details.
|
||||||
class MS_CORE_API Assign : public PrimitiveC {
|
class MIND_API Assign : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Assign);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Assign() : PrimitiveC(kNameAssign) { InitIOName({"ref", "value"}, {"output"}); }
|
Assign() : BaseOperator(kNameAssign) { InitIOName({"ref", "value"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~Assign() = default;
|
|
||||||
MS_DECLARE_PARENT(Assign, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Assign for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Assign for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "ops/assign_add.h"
|
#include "ops/assign_add.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -39,6 +40,8 @@ TypePtr AssignAddInferType(const PrimitivePtr &primitive, const std::vector<Abst
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(AssignAdd, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,27 +19,24 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAssignAdd = "AssignAdd";
|
constexpr auto kNameAssignAdd = "AssignAdd";
|
||||||
/// \brief Updates a Parameter by adding a value to it.
|
/// \brief Updates a Parameter by adding a value to it.
|
||||||
/// Refer to Python API @ref mindspore.ops.AssignAdd for more details.
|
/// Refer to Python API @ref mindspore.ops.AssignAdd for more details.
|
||||||
class MS_CORE_API AssignAdd : public PrimitiveC {
|
class MIND_API AssignAdd : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(AssignAdd);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
AssignAdd() : PrimitiveC(kNameAssignAdd) { InitIOName({"ref", "value"}, {"output"}); }
|
AssignAdd() : BaseOperator(kNameAssignAdd) { InitIOName({"ref", "value"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~AssignAdd() = default;
|
|
||||||
MS_DECLARE_PARENT(AssignAdd, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AssignAdd for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AssignAdd for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimAssignAddPtr = std::shared_ptr<AssignAdd>;
|
using kPrimAssignAddPtr = std::shared_ptr<AssignAdd>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -41,6 +42,7 @@ TypePtr AssignSubInferType(const PrimitivePtr &primitive, const std::vector<Abst
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(AssignSub, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -19,22 +19,20 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAssignSub = "AssignSub";
|
constexpr auto kNameAssignSub = "AssignSub";
|
||||||
class AssignSub : public PrimitiveC {
|
class MIND_API AssignSub : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
AssignSub() : PrimitiveC(kNameAssignSub) { InitIOName({"val", "value"}, {"val"}); }
|
MIND_API_BASE_MEMBER(AssignSub);
|
||||||
~AssignSub() = default;
|
AssignSub() : BaseOperator(kNameAssignSub) { InitIOName({"val", "value"}, {"val"}); }
|
||||||
MS_DECLARE_PARENT(AssignSub, PrimitiveC);
|
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AssignSubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
using kPrimAssignSubPtr = std::shared_ptr<AssignSub>;
|
using kPrimAssignSubPtr = std::shared_ptr<AssignSub>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -52,6 +53,8 @@ TypePtr AtanInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
||||||
return x_type;
|
return x_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Atan, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
|
@ -20,27 +20,24 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAtan = "Atan";
|
constexpr auto kNameAtan = "Atan";
|
||||||
/// \brief Computes the trigonometric inverse tangent of the input element-wise.
|
/// \brief Computes the trigonometric inverse tangent of the input element-wise.
|
||||||
/// Refer to Python API @ref mindspore.ops.Atan for more details.
|
/// Refer to Python API @ref mindspore.ops.Atan for more details.
|
||||||
class MS_CORE_API Atan : public PrimitiveC {
|
class MIND_API Atan : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Atan);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Atan() : PrimitiveC(kNameAtan) { InitIOName({"x"}, {"output"}); }
|
Atan() : BaseOperator(kNameAtan) {}
|
||||||
/// \brief Destructor.
|
|
||||||
~Atan() = default;
|
|
||||||
MS_DECLARE_PARENT(Atan, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Atan for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Atan for the inputs.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -58,6 +59,8 @@ TypePtr AtanhInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
||||||
return x_type;
|
return x_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(Atanh, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto type = AtanhInferType(primitive, input_args);
|
auto type = AtanhInferType(primitive, input_args);
|
||||||
|
|
|
@ -20,22 +20,20 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAtanh = "Atanh";
|
constexpr auto kNameAtanh = "Atanh";
|
||||||
class Atanh : public PrimitiveC {
|
class MIND_API Atanh : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
Atanh() : PrimitiveC(kNameAtanh) { InitIOName({"x"}, {"output"}); }
|
MIND_API_BASE_MEMBER(Atanh);
|
||||||
~Atanh() = default;
|
Atanh() : BaseOperator(kNameAtanh) { InitIOName({"x"}, {"output"}); }
|
||||||
MS_DECLARE_PARENT(Atanh, PrimitiveC);
|
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AtanhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
using PrimAtanhPtr = std::shared_ptr<Atanh>;
|
using PrimAtanhPtr = std::shared_ptr<Atanh>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -16,7 +16,10 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/attention.h"
|
#include "ops/attention.h"
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore::ops {
|
namespace mindspore::ops {
|
||||||
|
MIND_API_BASE_IMPL(Attention, PrimitiveC, BaseOperator);
|
||||||
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
|
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
|
||||||
} // namespace mindspore::ops
|
} // namespace mindspore::ops
|
||||||
|
|
|
@ -19,25 +19,23 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAttention = "Attention";
|
constexpr auto kNameAttention = "Attention";
|
||||||
/// \brief MultiHead-Attention op in MindIR.
|
/// \brief MultiHead-Attention op in MindIR.
|
||||||
class MS_CORE_API Attention : public PrimitiveC {
|
class MIND_API Attention : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(Attention);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
Attention() : PrimitiveC(kNameAttention) {
|
Attention() : BaseOperator(kNameAttention) {
|
||||||
InitIOName(
|
InitIOName(
|
||||||
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"},
|
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"},
|
||||||
{"output"});
|
{"output"});
|
||||||
}
|
}
|
||||||
/// \brief Destructor.
|
|
||||||
~Attention() override = default;
|
|
||||||
MS_DECLARE_PARENT(Attention, PrimitiveC);
|
|
||||||
/// \brief Initialize Attention op.
|
/// \brief Initialize Attention op.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
|
|
|
@ -23,24 +23,28 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(AudioSpectrogram, PrimitiveC, BaseOperator);
|
||||||
void AudioSpectrogram::set_window_size(const int64_t window_size) {
|
void AudioSpectrogram::set_window_size(const int64_t window_size) {
|
||||||
(void)this->AddAttr(kWindowSize, MakeValue(window_size));
|
(void)this->AddAttr(kWindowSize, api::MakeValue(window_size));
|
||||||
}
|
}
|
||||||
int64_t AudioSpectrogram::get_window_size() const {
|
int64_t AudioSpectrogram::get_window_size() const {
|
||||||
auto value_ptr = GetAttr(kWindowSize);
|
auto value_ptr = GetAttr(kWindowSize);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AudioSpectrogram::set_stride(const int64_t stride) { (void)this->AddAttr(kStride, MakeValue(stride)); }
|
void AudioSpectrogram::set_stride(const int64_t stride) { (void)this->AddAttr(kStride, api::MakeValue(stride)); }
|
||||||
int64_t AudioSpectrogram::get_stride() const {
|
int64_t AudioSpectrogram::get_stride() const {
|
||||||
auto value_ptr = GetAttr(kStride);
|
auto value_ptr = GetAttr(kStride);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AudioSpectrogram::set_mag_square(const bool mag_square) { (void)this->AddAttr(kMagSquare, MakeValue(mag_square)); }
|
void AudioSpectrogram::set_mag_square(const bool mag_square) {
|
||||||
|
(void)this->AddAttr(kMagSquare, api::MakeValue(mag_square));
|
||||||
|
}
|
||||||
bool AudioSpectrogram::get_mag_square() const {
|
bool AudioSpectrogram::get_mag_square() const {
|
||||||
auto value_ptr = GetAttr(kMagSquare);
|
auto value_ptr = GetAttr(kMagSquare);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
|
|
|
@ -20,23 +20,18 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
|
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
|
||||||
/// \brief AudioSpectrogram defined AudioSpectrogram operator prototype.
|
/// \brief AudioSpectrogram defined AudioSpectrogram operator prototype.
|
||||||
class MS_CORE_API AudioSpectrogram : public PrimitiveC {
|
class MIND_API AudioSpectrogram : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(AudioSpectrogram);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {}
|
AudioSpectrogram() : BaseOperator(kNameAudioSpectrogram) {}
|
||||||
|
|
||||||
/// \brief Destructor.
|
|
||||||
~AudioSpectrogram() = default;
|
|
||||||
|
|
||||||
MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
|
|
||||||
|
|
||||||
/// \brief Method to init the op's attributes.
|
/// \brief Method to init the op's attributes.
|
||||||
///
|
///
|
||||||
|
@ -75,8 +70,8 @@ class MS_CORE_API AudioSpectrogram : public PrimitiveC {
|
||||||
/// \return a boolean value.
|
/// \return a boolean value.
|
||||||
bool get_mag_square() const;
|
bool get_mag_square() const;
|
||||||
};
|
};
|
||||||
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -23,35 +23,37 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
|
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
|
||||||
int64_t swi = pad_mode;
|
int64_t swi = pad_mode;
|
||||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
(void)this->AddAttr(kPadMode, api::MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
|
PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
|
||||||
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||||
(void)this->AddAttr(kKernelSize,
|
(void)this->AddAttr(
|
||||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
kKernelSize, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
|
std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
|
||||||
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
|
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
|
||||||
(void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
(void)this->AddAttr(kStrides,
|
||||||
|
api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
|
std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
|
||||||
|
|
||||||
void AvgPool::set_format(const Format &format) {
|
void AvgPool::set_format(const Format &format) {
|
||||||
int64_t f = format;
|
int64_t f = format;
|
||||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
(void)this->AddAttr(kFormat, api::MakeValue(f));
|
||||||
}
|
}
|
||||||
|
|
||||||
Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
||||||
|
|
||||||
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
|
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, api::MakeValue(pad)); }
|
||||||
|
|
||||||
std::vector<int64_t> AvgPool::get_pad() const {
|
std::vector<int64_t> AvgPool::get_pad() const {
|
||||||
auto value_ptr = GetAttr(kPad);
|
auto value_ptr = GetAttr(kPad);
|
||||||
|
@ -60,7 +62,7 @@ std::vector<int64_t> AvgPool::get_pad() const {
|
||||||
|
|
||||||
void AvgPool::set_round_mode(const RoundMode &round_mode) {
|
void AvgPool::set_round_mode(const RoundMode &round_mode) {
|
||||||
int64_t swi = round_mode;
|
int64_t swi = round_mode;
|
||||||
(void)this->AddAttr(kRoundMode, MakeValue(swi));
|
(void)this->AddAttr(kRoundMode, api::MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
RoundMode AvgPool::get_round_mode() const {
|
RoundMode AvgPool::get_round_mode() const {
|
||||||
|
@ -78,6 +80,7 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
|
||||||
this->set_round_mode(round_mode);
|
this->set_round_mode(round_mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(AvgPool, PrimitiveC, BaseOperator);
|
||||||
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -21,22 +21,20 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "mindapi/base/format.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameAvgPool = "AvgPool";
|
constexpr auto kNameAvgPool = "AvgPool";
|
||||||
/// \brief Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool for more details.
|
/// \brief Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool for more details.
|
||||||
class MS_CORE_API AvgPool : public PrimitiveC {
|
class MIND_API AvgPool : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(AvgPool);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
|
AvgPool() : BaseOperator(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
|
||||||
explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
explicit AvgPool(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~AvgPool() = default;
|
|
||||||
MS_DECLARE_PARENT(AvgPool, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AvgPool for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.AvgPool for the inputs.
|
||||||
void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1},
|
void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1},
|
||||||
const PadMode &pad_mode = VALID, const Format &format = NCHW,
|
const PadMode &pad_mode = VALID, const Format &format = NCHW,
|
||||||
|
@ -80,8 +78,8 @@ class MS_CORE_API AvgPool : public PrimitiveC {
|
||||||
RoundMode get_round_mode() const;
|
RoundMode get_round_mode() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -180,6 +181,7 @@ TypePtr AvgPool3DInferType(const PrimitivePtr &primitive, const std::vector<Abst
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(AvgPool3D, PrimitiveC, BaseOperator);
|
||||||
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return abstract::MakeAbstract(AvgPool3DInferShape(primitive, input_args), AvgPool3DInferType(primitive, input_args));
|
return abstract::MakeAbstract(AvgPool3DInferShape(primitive, input_args), AvgPool3DInferType(primitive, input_args));
|
||||||
|
|
|
@ -21,24 +21,21 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
/// \brief 3D Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool3D for more details.
|
/// \brief 3D Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPool3D for more details.
|
||||||
class MS_CORE_API AvgPool3D : public PrimitiveC {
|
class MIND_API AvgPool3D : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(AvgPool3D);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
AvgPool3D() : PrimitiveC(prim::kPrimAvgPool3D->name()) { InitIOName({"input"}, {"output"}); }
|
AvgPool3D() : BaseOperator("AvgPool3D") { InitIOName({"input"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~AvgPool3D() = default;
|
|
||||||
MS_DECLARE_PARENT(AvgPool3D, PrimitiveC);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -15,13 +15,18 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/base_operator.h"
|
#include "ops/base_operator.h"
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
MIND_API_BASE_IMPL(BaseOperator, PrimitiveC, api::Primitive);
|
||||||
BaseOperator::BaseOperator(const std::string &name) : api::Primitive(std::make_shared<PrimitiveC>(name)) {}
|
BaseOperator::BaseOperator(const std::string &name) : api::Primitive(std::make_shared<PrimitiveC>(name)) {}
|
||||||
|
|
||||||
|
PrimitiveCPtr BaseOperator::GetPrim() {
|
||||||
|
PrimitiveCPtr res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
void BaseOperator::InitIOName(const std::vector<std::string> &inputs_name,
|
void BaseOperator::InitIOName(const std::vector<std::string> &inputs_name,
|
||||||
const std::vector<std::string> &outputs_name) {
|
const std::vector<std::string> &outputs_name) {
|
||||||
(void)AddAttr("input_names", api::MakeValue(inputs_name));
|
(void)AddAttr("input_names", api::MakeValue(inputs_name));
|
||||||
|
|
|
@ -17,26 +17,36 @@
|
||||||
#ifndef MINDSPORE_CORE_OPS_BASE_OPERATOR_
|
#ifndef MINDSPORE_CORE_OPS_BASE_OPERATOR_
|
||||||
#define MINDSPORE_CORE_OPS_BASE_OPERATOR_
|
#define MINDSPORE_CORE_OPS_BASE_OPERATOR_
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mindapi/ir/primitive.h"
|
#include "mindapi/ir/primitive.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
class AnalysisEngine;
|
class AnalysisEngine;
|
||||||
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;
|
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;
|
||||||
|
|
||||||
class AbstractBase;
|
class AbstractBase;
|
||||||
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
|
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class Primitive;
|
||||||
|
using PrimitivePtr = std::shared_ptr<Primitive>;
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class BaseOperator : public api::Primitive {
|
class PrimitiveC;
|
||||||
|
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
|
||||||
|
class MIND_API BaseOperator : public api::Primitive {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(BaseOperator);
|
||||||
explicit BaseOperator(const std::string &name);
|
explicit BaseOperator(const std::string &name);
|
||||||
~BaseOperator() = default;
|
PrimitiveCPtr GetPrim();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
|
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/tensor_construct_utils.h"
|
#include "utils/tensor_construct_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -134,14 +135,15 @@ TypePtr BatchMatmulInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_BASE_IMPL(BatchMatmul, PrimitiveC, BaseOperator);
|
||||||
void BatchMatmul::Init(bool transpose_a, bool transpose_b) {
|
void BatchMatmul::Init(bool transpose_a, bool transpose_b) {
|
||||||
set_transpose_a(transpose_a);
|
set_transpose_a(transpose_a);
|
||||||
set_transpose_b(transpose_b);
|
set_transpose_b(transpose_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, api::MakeValue(transpose_a)); }
|
||||||
|
|
||||||
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, api::MakeValue(transpose_b)); }
|
||||||
|
|
||||||
bool BatchMatmul::get_transpose_a() const {
|
bool BatchMatmul::get_transpose_a() const {
|
||||||
auto value_ptr = GetAttr(kTransposeA);
|
auto value_ptr = GetAttr(kTransposeA);
|
||||||
|
|
|
@ -18,21 +18,18 @@
|
||||||
#define MINDSPORE_CORE_OPS_BATCH_MATMUL_H_
|
#define MINDSPORE_CORE_OPS_BATCH_MATMUL_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/base_operator.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "mindapi/base/types.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
/// \brief Computes matrix multiplication between two tensors by batch.
|
/// \brief Computes matrix multiplication between two tensors by batch.
|
||||||
/// Refer to Python API @ref mindspore.ops.BatchMatmul for more details.
|
/// Refer to Python API @ref mindspore.ops.BatchMatmul for more details.
|
||||||
class MS_CORE_API BatchMatmul : public PrimitiveC {
|
class MIND_API BatchMatmul : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(BatchMatmul);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
BatchMatmul() : PrimitiveC(prim::kPrimBatchMatMul->name()) { InitIOName({"x1", "x2"}, {"output"}); }
|
BatchMatmul() : BaseOperator("BatchMatMul") { InitIOName({"x1", "x2"}, {"output"}); }
|
||||||
/// \brief Destructor.
|
|
||||||
~BatchMatmul() = default;
|
|
||||||
MS_DECLARE_PARENT(BatchMatmul, PrimitiveC);
|
|
||||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.BatchMatmul for the inputs.
|
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.BatchMatmul for the inputs.
|
||||||
void Init(bool transpose_a = false, bool transpose_b = false);
|
void Init(bool transpose_a = false, bool transpose_b = false);
|
||||||
/// \brief Set transpose_a.
|
/// \brief Set transpose_a.
|
||||||
|
@ -48,8 +45,8 @@ class MS_CORE_API BatchMatmul : public PrimitiveC {
|
||||||
/// \return transpose_b.
|
/// \return transpose_b.
|
||||||
bool get_transpose_b() const;
|
bool get_transpose_b() const;
|
||||||
};
|
};
|
||||||
AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue