add cos infer
This commit is contained in:
parent
ef195c9de6
commit
22dd2c0373
|
@ -60,6 +60,7 @@ constexpr auto kAdd = "Add";
|
|||
constexpr auto kBiasAdd = "BiasAdd";
|
||||
constexpr auto kTile = "Tile";
|
||||
constexpr auto kBiasAddGrad = "BiasAddGrad";
|
||||
constexpr auto kCos = "Cos";
|
||||
|
||||
// Arrays
|
||||
constexpr auto kDynamicShape = "DynamicShape";
|
||||
|
@ -457,7 +458,7 @@ inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMi
|
|||
inline const PrimitivePtr kPrimCentralization = std::make_shared<Primitive>("Centralization");
|
||||
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>(kNeg);
|
||||
inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin");
|
||||
inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos");
|
||||
inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>(kCos);
|
||||
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>(kSub);
|
||||
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>(kMul);
|
||||
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,30 +22,26 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
abstract::ShapePtr CosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
TypePtr CosInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types, prim->name());
|
||||
return x_dtype;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr CosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
|
||||
return abstract::MakeAbstract(CosInferShape(primitive, input_args), CosInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameCos, Cos);
|
||||
} // namespace
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Cos, prim::kPrimCos, CosInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,17 +25,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCos = "Cos";
|
||||
class MS_CORE_API Cos : public PrimitiveC {
|
||||
public:
|
||||
Cos() : PrimitiveC(kNameCos) {}
|
||||
Cos() : PrimitiveC(prim::kPrimCos->name()) {}
|
||||
~Cos() = default;
|
||||
MS_DECLARE_PARENT(Cos, PrimitiveC);
|
||||
void Init(float alpha = 0.0);
|
||||
};
|
||||
AbstractBasePtr CosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimCos = std::shared_ptr<Cos>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_COS_H_
|
||||
|
|
|
@ -4083,7 +4083,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
|
|||
return mstype.float32
|
||||
|
||||
|
||||
class Cos(PrimitiveWithInfer):
|
||||
class Cos(Primitive):
|
||||
r"""
|
||||
Computes cosine of input element-wise.
|
||||
|
||||
|
@ -4115,13 +4115,6 @@ class Cos(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
"""Initialize Cos"""
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class ACos(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue