add cos infer

This commit is contained in:
jjj 2021-08-24 09:01:35 +08:00
parent ef195c9de6
commit 22dd2c0373
4 changed files with 18 additions and 32 deletions

View File

@ -60,6 +60,7 @@ constexpr auto kAdd = "Add";
constexpr auto kBiasAdd = "BiasAdd";
constexpr auto kTile = "Tile";
constexpr auto kBiasAddGrad = "BiasAddGrad";
constexpr auto kCos = "Cos";
// Arrays
constexpr auto kDynamicShape = "DynamicShape";
@ -457,7 +458,7 @@ inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMi
inline const PrimitivePtr kPrimCentralization = std::make_shared<Primitive>("Centralization");
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>(kNeg);
inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin");
inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos");
inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>(kCos);
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>(kSub);
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>(kMul);
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,30 +22,26 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
abstract::ShapePtr CosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
TypePtr CosInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto x_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types, prim->name());
return x_dtype;
}
} // namespace
AbstractBasePtr CosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::MakeAbstract(CosInferShape(primitive, input_args), CosInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameCos, Cos);
} // namespace
REGISTER_PRIMITIVE_EVAL_IMPL(Cos, prim::kPrimCos, CosInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,17 +25,13 @@
namespace mindspore {
namespace ops {
constexpr auto kNameCos = "Cos";
class MS_CORE_API Cos : public PrimitiveC {
public:
Cos() : PrimitiveC(kNameCos) {}
Cos() : PrimitiveC(prim::kPrimCos->name()) {}
~Cos() = default;
MS_DECLARE_PARENT(Cos, PrimitiveC);
void Init(float alpha = 0.0);
};
AbstractBasePtr CosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCos = std::shared_ptr<Cos>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_COS_H_

View File

@ -4083,7 +4083,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
return mstype.float32
class Cos(PrimitiveWithInfer):
class Cos(Primitive):
r"""
Computes cosine of input element-wise.
@ -4115,13 +4115,6 @@ class Cos(PrimitiveWithInfer):
def __init__(self):
"""Initialize Cos"""
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
class ACos(PrimitiveWithInfer):
r"""