!46951 新增scalar ops原语

Merge pull request !46951 from huoxinyou/1218scalarop
This commit is contained in:
i-robot 2023-02-01 01:46:11 +00:00 committed by Gitee
commit 3c6592e8b1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
60 changed files with 1601 additions and 156 deletions

View File

@ -30,13 +30,13 @@ PrimToFunction::PrimToFunction()
{kScalarUsub, kPrimTypeNumOneArg}, {kScalarAdd, kPrimTypeNumTwoArgs},
{"bool_and", kPrimTypeNumTwoArgs}, {"bool_eq", kPrimTypeNumTwoArgs},
{"bool_or", kPrimTypeNumTwoArgs}, {kScalarDiv, kPrimTypeNumTwoArgs},
{"scalar_eq", kPrimTypeNumTwoArgs}, {"scalar_ge", kPrimTypeNumTwoArgs},
{"scalar_gt", kPrimTypeNumTwoArgs}, {"scalar_le", kPrimTypeNumTwoArgs},
{"scalar_lt", kPrimTypeNumTwoArgs}, {"scalar_ne", kPrimTypeNumTwoArgs},
{kScalarEq, kPrimTypeNumTwoArgs}, {kScalarGe, kPrimTypeNumTwoArgs},
{kScalarGt, kPrimTypeNumTwoArgs}, {kScalarLe, kPrimTypeNumTwoArgs},
{kScalarLt, kPrimTypeNumTwoArgs}, {"scalar_ne", kPrimTypeNumTwoArgs},
{kScalarMod, kPrimTypeNumTwoArgs}, {kScalarMul, kPrimTypeNumTwoArgs},
{kScalarPow, kPrimTypeNumTwoArgs}, {kScalarSub, kPrimTypeNumTwoArgs},
{kScalarFloordiv, kPrimTypeNumTwoArgs}, {"bit_and", kPrimTypeNumTwoArgs},
{"bit_or", kPrimTypeNumTwoArgs}, {"bit_xor", kPrimTypeNumTwoArgs},
{kScalarFloordiv, kPrimTypeNumTwoArgs}, {kScalarBitwiseAnd, kPrimTypeNumTwoArgs},
{kScalarBitwiseOr, kPrimTypeNumTwoArgs}, {"bit_xor", kPrimTypeNumTwoArgs},
{"bit_left_shift", kPrimTypeNumTwoArgs}, {"bit_right_shift", kPrimTypeNumTwoArgs},
{kStringNot, kPrimTypeStrOneArg}, {kStringConcat, kPrimTypeStrTwoArgs},
{kStringIn, kPrimTypeStrTwoArgs}, {kStringEq, kPrimTypeStrTwoArgs},

View File

@ -3063,27 +3063,15 @@ using PrimitiveToImplMap = mindspore::HashMap<PrimitivePtr, PrimitiveImplInferVa
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
using R = PrimitiveToImplMap::mapped_type;
static PrimitiveToImplMap uniform_prim_implement_map{
{prim::kPrimScalarAdd, R{prim::ScalarAdd, true, nullptr, true}},
{prim::kPrimScalarSub, R{prim::ScalarSub, true, nullptr, true}},
{prim::kPrimScalarMul, R{prim::ScalarMul, true, nullptr, true}},
{prim::kPrimScalarDiv, R{prim::ScalarDiv, true, nullptr, true}},
{prim::kPrimScalarMod, R{prim::ScalarMod, true, nullptr, true}},
{prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}},
{prim::kPrimScalarFloordiv, R{prim::ScalarFloordiv, true, nullptr, true}},
{prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}},
{prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}},
{prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}},
{prim::kPrimBitAnd, R{prim::BitAnd, true, nullptr, true}},
{prim::kPrimBitOr, R{prim::BitOr, true, nullptr, true}},
{prim::kPrimBitXor, R{prim::BitXor, true, nullptr, true}},
{prim::kPrimBitLeftShift, R{prim::BitLeftShift, true, nullptr, true}},
{prim::kPrimBitRightShift, R{prim::BitRightShift, true, nullptr, true}},
{prim::kPrimScalarEq, R{prim::ScalarEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLt, R{prim::ScalarLt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGt, R{prim::ScalarGt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLe, R{prim::ScalarLe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGe, R{prim::ScalarGe, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolNot, R{prim::BoolNot, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared<Bool>(), true}},

View File

@ -76,6 +76,14 @@ constexpr auto kScalarTrunc = "ScalarTrunc";
constexpr auto kScalarFloor = "ScalarFloor";
constexpr auto kScalarUadd = "ScalarUadd";
constexpr auto kScalarUsub = "ScalarUsub";
constexpr auto kScalarEq = "ScalarEqual";
constexpr auto kScalarLt = "ScalarLess";
constexpr auto kScalarGt = "ScalarGreater";
constexpr auto kScalarLe = "ScalarLessEqual";
constexpr auto kScalarGe = "ScalarGreaterEqual";
constexpr auto kScalarBool = "ScalarBool";
constexpr auto kScalarBitwiseAnd = "ScalarBitwiseAnd";
constexpr auto kScalarBitwiseOr = "ScalarBitwiseOr";
constexpr auto kExp = "Exp";
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
@ -505,12 +513,15 @@ GVAR_DEF(PrimitivePtr, kPrimStringMul, std::make_shared<Primitive>(kStringMul));
GVAR_DEF(PrimitivePtr, kPrimStringGetItem, std::make_shared<Primitive>(kStringGetItem));
// Comparisons
GVAR_DEF(PrimitivePtr, kPrimScalarEq, std::make_shared<Primitive>("scalar_eq"));
GVAR_DEF(PrimitivePtr, kPrimScalarLt, std::make_shared<Primitive>("scalar_lt"));
GVAR_DEF(PrimitivePtr, kPrimScalarGt, std::make_shared<Primitive>("scalar_gt"));
GVAR_DEF(PrimitivePtr, kPrimScalarEq, std::make_shared<Primitive>(kScalarEq));
GVAR_DEF(PrimitivePtr, kPrimScalarLt, std::make_shared<Primitive>(kScalarLt));
GVAR_DEF(PrimitivePtr, kPrimScalarGt, std::make_shared<Primitive>(kScalarGt));
GVAR_DEF(PrimitivePtr, kPrimScalarNe, std::make_shared<Primitive>("scalar_ne"));
GVAR_DEF(PrimitivePtr, kPrimScalarLe, std::make_shared<Primitive>("scalar_le"));
GVAR_DEF(PrimitivePtr, kPrimScalarGe, std::make_shared<Primitive>("scalar_ge"));
GVAR_DEF(PrimitivePtr, kPrimScalarLe, std::make_shared<Primitive>(kScalarLe));
GVAR_DEF(PrimitivePtr, kPrimScalarGe, std::make_shared<Primitive>(kScalarGe));
GVAR_DEF(PrimitivePtr, kPrimScalarBool, std::make_shared<Primitive>(kScalarBool));
GVAR_DEF(PrimitivePtr, kPrimScalarBitwiseAnd, std::make_shared<Primitive>(kScalarBitwiseAnd));
GVAR_DEF(PrimitivePtr, kPrimScalarBitwiseOr, std::make_shared<Primitive>(kScalarBitwiseOr));
GVAR_DEF(PrimitivePtr, kPrimBoolNot, std::make_shared<Primitive>("bool_not"));
GVAR_DEF(PrimitivePtr, kPrimBoolAnd, std::make_shared<Primitive>("bool_and"));
GVAR_DEF(PrimitivePtr, kPrimBoolOr, std::make_shared<Primitive>("bool_or"));
@ -1503,9 +1514,7 @@ GVAR_DEF(PrimitivePtr, kPrimTileShape, std::make_shared<Primitive>("tile_shape")
GVAR_DEF(PrimitivePtr, kPrimGenerateShapeIndex, std::make_shared<Primitive>("generate_shape_index"));
GVAR_DEF(PrimitivePtr, kPrimGenerateInverseIndex, std::make_shared<Primitive>("generate_inverse_index"));
GVAR_DEF(PrimitivePtr, kPrimRealMakeList, std::make_shared<Primitive>(kRealMakeList));
GVAR_DEF(PrimitivePtr, kPrimRealTupleGetItem, std::make_shared<Primitive>(kRealTupleGetItem));
GVAR_DEF(PrimitivePtr, kPrimRealListGetItem, std::make_shared<Primitive>(kRealListGetItem));
GVAR_DEF(PrimitivePtr, kPrimListToTensor, std::make_shared<Primitive>(kListToTensor));
GVAR_DEF(PrimitivePtr, kPrimScalarToTensor, std::make_shared<Primitive>(kScalarToTensor));
GVAR_DEF(PrimitivePtr, kPrimTensorToTuple, std::make_shared<Primitive>(kTensorToTuple));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -20,7 +20,6 @@
#include <vector>
#include "ops/op_utils.h"
#include "ops/real_makelist.h"
#include "abstract/ops/op_infer.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
@ -29,7 +28,6 @@
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MakeList, BaseOperator);
MIND_API_OPERATOR_IMPL(RealMakeList, BaseOperator);
AbstractBasePtr MakeListInnerInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractList>(input_args);
}
@ -51,6 +49,5 @@ class MakeListInfer : public abstract::OpInferBase {
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(MakeList, prim::kPrimMakeList, MakeListInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealMakeList, prim::kPrimRealMakeList, MakeListInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -15,6 +15,7 @@
*/
#include <string>
#include <map>
#include <set>
#include <vector>
#include <algorithm>
@ -682,6 +683,57 @@ template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractList>(const Pri
template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractTuple>(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
template <typename T>
T GetScalarValue(const std::string &op_name, const ValuePtr &elem) {
T res;
MS_EXCEPTION_IF_NULL(elem);
if (elem->isa<Int64Imm>()) {
auto elem_value = GetValue<int64_t>(elem);
res = static_cast<T>(elem_value);
} else if (elem->isa<Int32Imm>()) {
auto elem_value = GetValue<int32_t>(elem);
res = static_cast<T>(elem_value);
} else if (elem->isa<FP64Imm>()) {
auto elem_value = GetValue<double>(elem);
res = static_cast<T>(elem_value);
} else if (elem->isa<FP32Imm>()) {
auto elem_value = GetValue<float>(elem);
res = static_cast<T>(elem_value);
} else if (elem->isa<BoolImm>()) {
auto elem_value = GetValue<bool>(elem);
res = static_cast<T>(elem_value);
} else {
MS_EXCEPTION(TypeError) << "For op '" << op_name
<< "' input must be [int32, int64, float32, float64, bool], but got " << elem->ToString();
}
return res;
}
template int64_t GetScalarValue(const std::string &op_name, const ValuePtr &elem);
template int32_t GetScalarValue(const std::string &op_name, const ValuePtr &elem);
template double GetScalarValue(const std::string &op_name, const ValuePtr &elem);
template float GetScalarValue(const std::string &op_name, const ValuePtr &elem);
template bool GetScalarValue(const std::string &op_name, const ValuePtr &elem);
TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name) {
static std::map<TypeId, size_t> prio_map = {{kNumberTypeFloat64, 1},
{kNumberTypeFloat32, 2},
{kNumberTypeInt64, 3},
{kNumberTypeInt32, 4},
{kNumberTypeBool, 5}};
auto x_iter = prio_map.find(x_type->type_id());
auto y_iter = prio_map.find(y_type->type_id());
if (x_iter == prio_map.end() || y_iter == prio_map.end()) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the x and y type should be int or float, but got x type: " << x_type
<< " y type: " << y_type;
}
if (x_iter->second < y_iter->second) {
return x_type;
}
return y_type;
}
bool IsValueKnown(const ValuePtr &value) {
// For now if the Abstract is a container of elements such as AbstractSequence and AbstractDictionary,
// the BuildValue returns AnyValue if any one of the elements' value is AnyValue

View File

@ -112,6 +112,11 @@ std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const Abstract
template <typename T>
AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
template <typename T>
T GetScalarValue(const std::string &op_name, const ValuePtr &elem);
TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name);
bool IsValueKnown(const ValuePtr &value);
constexpr auto kCSRAvgRows = "csr_avg_rows";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,21 +14,23 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_REAL_LIST_GETITEM_H_
#define MINDSPORE_CORE_OPS_REAL_LIST_GETITEM_H_
#ifndef MINDSPORE_CORE_OPS_SCALAR_ADD_H_
#define MINDSPORE_CORE_OPS_SCALAR_ADD_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief RealListGetItem op is used to get list[index] value, list is a dynamic length list or index is variable
class MIND_API RealListGetItem : public BaseOperator {
/// \brief ScalarAdd op is used to add between variable scalar.
class MIND_API ScalarAdd : public BaseOperator {
public:
MIND_API_BASE_MEMBER(RealListGetItem);
MIND_API_BASE_MEMBER(ScalarAdd);
/// \brief Constructor.
RealListGetItem() : BaseOperator(prim::kRealListGetItem) { InitIOName({"input", "index"}, {"output"}); }
ScalarAdd() : BaseOperator(prim::kScalarAdd) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_REAL_LIST_GETITEM_H_
#endif // MINDSPORE_CORE_OPS_SCALAR_ADD_H_

View File

@ -0,0 +1,325 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <string>
#include <memory>
#include "ops/op_utils.h"
#include "abstract/ops/op_infer.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
#include "ops/scalar_add.h"
#include "ops/scalar_sub.h"
#include "ops/scalar_mul.h"
#include "ops/scalar_div.h"
#include "ops/scalar_mod.h"
#include "ops/scalar_eq.h"
#include "ops/scalar_lt.h"
#include "ops/scalar_gt.h"
#include "ops/scalar_le.h"
#include "ops/scalar_ge.h"
namespace mindspore {
namespace ops {
template <typename T>
ValuePtr AddImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
#ifndef _MSC_VER
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
T res;
if (__builtin_add_overflow(x, y, &res)) {
MS_EXCEPTION(ValueError) << "For prim '" << op_name
<< "' Overflow of the sum of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
return MakeValue(res);
}
#endif
return MakeValue(x + y);
}
template <typename T>
ValuePtr SubImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
#ifndef _MSC_VER
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
T res;
if (__builtin_sub_overflow(x, y, &res)) {
MS_EXCEPTION(ValueError) << "For prim '" << op_name
<< "' Overflow of the sub of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
return MakeValue(res);
}
#endif
return MakeValue(x - y);
}
template <typename T>
ValuePtr MulImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
#ifndef _MSC_VER
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
T res;
if (__builtin_mul_overflow(x, y, &res)) {
MS_EXCEPTION(ValueError) << "For prim '" << op_name
<< "' Overflow of the mul of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
return MakeValue(res);
}
#endif
return MakeValue(x * y);
}
template <typename T>
ValuePtr DivImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
T zero = 0;
if (y == zero) {
MS_EXCEPTION(ValueError) << "The divisor could not be zero. But the divisor is zero now.";
}
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
MS_EXCEPTION(ValueError) << "For prim '" << op_name
<< "' Overflow of the div of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
}
return MakeValue(static_cast<float>(x) / static_cast<float>(y));
}
template <typename T>
ValuePtr ModImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
T zero = 0;
if (y == zero) {
MS_EXCEPTION(ValueError) << "Cannot perform modulo operation on zero.";
}
if constexpr (std::is_signed<T>::value) {
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
MS_EXCEPTION(ValueError) << "For prim '" << op_name
<< "' Overflow of the mod of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
}
T n = std::floor(static_cast<float>(x) / static_cast<float>(y));
T res = x - n * y;
return MakeValue(res);
}
template <typename T>
ValuePtr EqImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
if (std::isinf(static_cast<double>(x)) && std::isinf(static_cast<double>(y))) {
return MakeValue((x > 0 && y > 0) || (x < 0 && y < 0));
}
double error = static_cast<double>(x) - static_cast<double>(y);
error = fabs(error);
return MakeValue(error < DBL_EPSILON);
}
template <typename T>
ValuePtr LtImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
return MakeValue(x < y);
}
template <typename T>
ValuePtr GtImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
return MakeValue(x > y);
}
template <typename T>
ValuePtr LeImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
return MakeValue(x <= y);
}
template <typename T>
ValuePtr GeImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
return MakeValue(x >= y);
}
using MathImplFunc = std::function<ValuePtr(const ValuePtr &, const ValuePtr &, const std::string &)>;
template <typename T>
MathImplFunc ChooseFunc(const std::string &prim_name) {
std::map<std::string, MathImplFunc> infer_value_func_map = {
{prim::kScalarAdd, AddImpl<T>}, {prim::kScalarSub, SubImpl<T>}, {prim::kScalarMul, MulImpl<T>},
{prim::kScalarDiv, DivImpl<T>}, {prim::kScalarMod, ModImpl<T>}, {prim::kScalarEq, EqImpl<T>},
{prim::kScalarGt, GtImpl<T>}, {prim::kScalarLt, LtImpl<T>}, {prim::kScalarGe, GeImpl<T>},
{prim::kScalarLe, LeImpl<T>}};
auto iter = infer_value_func_map.find(prim_name);
if (iter == infer_value_func_map.end()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "' don't support. Only support [Add, Sub, Mul, Div, Mod, Eq, Le, Ge, Lt, Gt]";
}
return iter->second;
}
class ScalarArithmeticInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
constexpr size_t input_len = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, op_name);
auto elem_x = input_args[0];
auto elem_y = input_args[kIndex1];
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
<< " and y: " << elem_y->ToString();
}
return abstract::kNoShape;
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_type = input_args[0]->BuildType();
auto y_type = input_args[kIndex1]->BuildType();
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64};
std::set<std::string> compare_ops = {prim::kScalarEq, prim::kScalarGe, prim::kScalarGt, prim::kScalarLt,
prim::kScalarLe};
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
(void)CheckAndConvertUtils::CheckSubClass("y_dtype", y_type, check_types, prim_name);
auto iter = compare_ops.find(prim_name);
if (prim_name == prim::kScalarDiv) {
return kFloat32;
}
if (iter != compare_ops.end()) {
return kBool;
}
return HighPriorityType(x_type, y_type, prim_name);
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_num = 2;
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
constexpr size_t x_index = 0;
constexpr size_t y_index = 1;
auto elem_x = input_args[x_index];
auto elem_y = input_args[y_index];
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
<< " and y: " << elem_y->ToString();
}
auto x_value = elem_x->BuildValue();
auto y_value = elem_y->BuildValue();
if (x_value == kAnyValue || y_value == kAnyValue) {
return nullptr;
}
auto x_type = input_args[x_index]->BuildType();
auto y_type = input_args[y_index]->BuildType();
auto res_type = HighPriorityType(x_type, y_type, op_name);
ValuePtr result;
switch (res_type->type_id()) {
case kNumberTypeInt32: {
auto func = ChooseFunc<int32_t>(op_name);
result = func(x_value, y_value, op_name);
break;
}
case kNumberTypeInt64: {
auto func = ChooseFunc<int64_t>(op_name);
result = func(x_value, y_value, op_name);
break;
}
case kNumberTypeFloat32: {
auto func = ChooseFunc<float>(op_name);
result = func(x_value, y_value, op_name);
break;
}
case kNumberTypeFloat64: {
auto func = ChooseFunc<double>(op_name);
result = func(x_value, y_value, op_name);
break;
}
default: {
MS_EXCEPTION(TypeError) << "For '" << op_name
<< "', the supported type is in the list: [int32, int64, float32, float64], but got "
<< res_type->ToString() << ".";
}
}
return result;
}
private:
std::function<ValuePtr(const ValuePtr &, const ValuePtr &, const std::string &)> infer_value_func_;
};
MIND_API_OPERATOR_IMPL(ScalarAdd, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarSub, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarMul, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarDiv, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarMod, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarEqual, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarGreater, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarGreaterEqual, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarLess, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarLessEqual, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarAdd, prim::kPrimScalarAdd, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarSub, prim::kPrimScalarSub, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMul, prim::kPrimScalarMul, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarDiv, prim::kPrimScalarDiv, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMod, prim::kPrimScalarMod, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarEqual, prim::kPrimScalarEq, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreater, prim::kPrimScalarGt, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreaterEqual, prim::kPrimScalarGe, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLess, prim::kPrimScalarLt, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLessEqual, prim::kPrimScalarLe, ScalarArithmeticInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,123 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <string>
#include <memory>
#include "ops/op_utils.h"
#include "abstract/ops/op_infer.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
#include "ops/scalar_bitwise_or.h"
#include "ops/scalar_bitwise_and.h"
namespace mindspore {
namespace ops {
template <typename T>
T BitwiseImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
if (op_name == prim::kScalarBitwiseAnd) {
return x & y;
} else {
return x | y;
}
}
class ScalarBitwiseInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
constexpr size_t input_len = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, op_name);
auto elem_x = input_args[0];
auto elem_y = input_args[kIndex1];
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
<< " and y: " << elem_y->ToString();
}
return abstract::kNoShape;
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_type = input_args[0]->BuildType();
auto y_type = input_args[kIndex1]->BuildType();
std::set<TypePtr> check_types = {kInt32, kInt64, kBool};
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
(void)CheckAndConvertUtils::CheckSubClass("y_dtype", y_type, check_types, prim_name);
return HighPriorityType(x_type, y_type, prim_name);
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_num = 2;
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
constexpr size_t x_index = 0;
constexpr size_t y_index = 1;
auto elem_x = input_args[x_index];
auto elem_y = input_args[y_index];
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
<< " and y: " << elem_y->ToString();
}
auto x_value = elem_x->BuildValue();
auto y_value = elem_y->BuildValue();
if (x_value == kAnyValue || y_value == kAnyValue) {
return nullptr;
}
auto res_type = InferType(primitive, input_args);
ValuePtr res;
switch (res_type->type_id()) {
case kNumberTypeInt32: {
res = MakeValue(BitwiseImpl<int32_t>(x_value, y_value, op_name));
break;
}
case kNumberTypeInt64: {
res = MakeValue(BitwiseImpl<int64_t>(x_value, y_value, op_name));
break;
}
case kNumberTypeBool: {
res = MakeValue(BitwiseImpl<bool>(x_value, y_value, op_name));
break;
}
default: {
MS_EXCEPTION(TypeError) << "For '" << op_name
<< "', the supported type is in the list: [int32, int64, bool], but got "
<< res_type->ToString() << ".";
}
}
return res;
}
};
MIND_API_OPERATOR_IMPL(ScalarBitwiseOr, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarBitwiseAnd, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseOr, prim::kPrimScalarBitwiseOr, ScalarBitwiseInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseAnd, prim::kPrimScalarBitwiseAnd, ScalarBitwiseInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_BITWISE_AND_H_
#define MINDSPORE_CORE_OPS_SCALAR_BITWISE_AND_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief
class MIND_API ScalarBitwiseAnd : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarBitwiseAnd);
/// \brief Constructor.
ScalarBitwiseAnd() : BaseOperator(prim::kScalarBitwiseAnd) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_BITWISE_AND_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,23 +14,23 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_REAL_MAKELIST_H_
#define MINDSPORE_CORE_OPS_REAL_MAKELIST_H_
#ifndef MINDSPORE_CORE_OPS_SCALAR_BITWISE_OR_H_
#define MINDSPORE_CORE_OPS_SCALAR_BITWISE_OR_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief RealMakeList op
class MIND_API RealMakeList : public BaseOperator {
/// \brief
class MIND_API ScalarBitwiseOr : public BaseOperator {
public:
MIND_API_BASE_MEMBER(RealMakeList);
MIND_API_BASE_MEMBER(ScalarBitwiseOr);
/// \brief Constructor.
RealMakeList() : BaseOperator(prim::kRealMakeList) {}
ScalarBitwiseOr() : BaseOperator(prim::kScalarBitwiseOr) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_REAL_MAKELIST_H_
#endif // MINDSPORE_CORE_OPS_SCALAR_BITWISE_OR_H_

View File

@ -0,0 +1,110 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <string>
#include <memory>
#include <set>
#include "ops/scalar_bool.h"
#include "ops/op_utils.h"
#include "abstract/ops/op_infer.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
class ScalarBoolInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
constexpr size_t input_len = 1;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, op_name);
auto elem = input_args[0];
if (!elem->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got : " << elem->ToString();
}
return abstract::kNoShape;
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_type = input_args[0]->BuildType();
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64, kBool};
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
return kBool;
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_num = 1;
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto elem = input_args[0];
if (!elem->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got : " << elem->ToString();
}
auto x_valueptr = elem->BuildValue();
if (x_valueptr == kAnyValue) {
return nullptr;
}
auto x_type = input_args[0]->BuildType();
bool res;
switch (x_type->type_id()) {
case kNumberTypeInt32: {
auto elem_value = GetValue<int32_t>(x_valueptr);
res = static_cast<bool>(elem_value);
break;
}
case kNumberTypeInt64: {
auto elem_value = GetValue<int64_t>(x_valueptr);
res = static_cast<bool>(elem_value);
break;
}
case kNumberTypeFloat32: {
auto elem_value = GetValue<float>(x_valueptr);
res = static_cast<bool>(elem_value);
break;
}
case kNumberTypeFloat64: {
auto elem_value = GetValue<double>(x_valueptr);
res = static_cast<bool>(elem_value);
break;
}
case kNumberTypeBool: {
auto elem_value = GetValue<bool>(x_valueptr);
res = static_cast<bool>(elem_value);
}
default: {
MS_EXCEPTION(TypeError)
<< "For '" << op_name
<< "', the supported type is in the list: [int32, int64, float32, float64, bool], but got "
<< x_type->ToString() << ".";
}
}
return MakeValue(res);
}
};
MIND_API_OPERATOR_IMPL(ScalarBool, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBool, prim::kPrimScalarBool, ScalarBoolInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_BOOL_H_
#define MINDSPORE_CORE_OPS_SCALAR_BOOL_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarBool op is used to calculate the input true or false.
class MIND_API ScalarBool : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarBool);
/// \brief Constructor.
ScalarBool() : BaseOperator(prim::kScalarBool) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_BOOL_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_DIV_H_
#define MINDSPORE_CORE_OPS_SCALAR_DIV_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarDiv op is used to div between variable scalar.
class MIND_API ScalarDiv : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarDiv);
/// \brief Constructor.
ScalarDiv() : BaseOperator(prim::kScalarDiv) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_DIV_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_EQ_H_
#define MINDSPORE_CORE_OPS_SCALAR_EQ_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarEqual op is used to judge equal between variable scalar.
class MIND_API ScalarEqual : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarEqual);
/// \brief Constructor.
ScalarEqual() : BaseOperator(prim::kScalarEq) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_EQ_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_GE_H_
#define MINDSPORE_CORE_OPS_SCALAR_GE_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarGreaterEqual op is used to judge greaterEqual between variable scalar.
class MIND_API ScalarGreaterEqual : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarGreaterEqual);
/// \brief Constructor.
ScalarGreaterEqual() : BaseOperator(prim::kScalarGe) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_GE_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_GT_H_
#define MINDSPORE_CORE_OPS_SCALAR_GT_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarGreater op is used to judge greater between variable scalar.
class MIND_API ScalarGreater : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarGreater);
/// \brief Constructor.
ScalarGreater() : BaseOperator(prim::kScalarGt) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_GT_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_LE_H_
#define MINDSPORE_CORE_OPS_SCALAR_LE_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarLessEqual op is used to judge lessEqual between variable scalar.
class MIND_API ScalarLessEqual : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarLessEqual);
/// \brief Constructor.
ScalarLessEqual() : BaseOperator(prim::kScalarLe) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_LE_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_LT_H_
#define MINDSPORE_CORE_OPS_SCALAR_LT_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarLess op is used to judge less between variable scalar.
class MIND_API ScalarLess : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarLess);
/// \brief Constructor.
ScalarLess() : BaseOperator(prim::kScalarLt) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_LT_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_MOD_H_
#define MINDSPORE_CORE_OPS_SCALAR_MOD_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarMod op is used to mod between variable scalar.
class MIND_API ScalarMod : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarMod);
/// \brief Constructor.
ScalarMod() : BaseOperator(prim::kScalarMod) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_MOD_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_MUL_H_
#define MINDSPORE_CORE_OPS_SCALAR_MUL_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarMul op is used to Mul between variable scalar.
class MIND_API ScalarMul : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarMul);
/// \brief Constructor.
ScalarMul() : BaseOperator(prim::kScalarMul) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_Mul_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCALAR_SUB_H_
#define MINDSPORE_CORE_OPS_SCALAR_SUB_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief ScalarSub op is used to add between variable scalar.
class MIND_API ScalarSub : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarSub);
/// \brief Constructor.
ScalarSub() : BaseOperator(prim::kScalarSub) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCALAR_SUB_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -88,10 +88,8 @@ class ScalarToTensorInfer : public abstract::OpInferBase {
MS_EXCEPTION_IF_NULL(attr);
}
if (!attr->isa<Type>()) {
MS_EXCEPTION(TypeError)
<< "For '" << prim_name
<< "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
"'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
MS_EXCEPTION(TypeError) << "For '" << prim_name << "the second input must be a `Type`, but got "
<< attr->type_name();
}
auto output_dtype = attr->cast<TypePtr>();
@ -128,5 +126,6 @@ class ScalarToTensorInfer : public abstract::OpInferBase {
}
};
MIND_API_OPERATOR_IMPL(ScalarToTensor, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarToTensor, prim::kPrimScalarToTensor, ScalarToTensorInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,7 +20,6 @@
#include "ops/tuple_get_item.h"
#include "ops/list_getitem.h"
#include "ops/real_tuple_getitem.h"
#include "ops/real_list_getitem.h"
#include "ops/op_utils.h"
#include "abstract/param_validator.h"
#include "abstract/ops/op_infer.h"
@ -104,10 +103,8 @@ class SequenceGetItemInfer : public abstract::OpInferBase {
MIND_API_OPERATOR_IMPL(TupleGetItem, BaseOperator);
MIND_API_OPERATOR_IMPL(RealTupleGetItem, BaseOperator);
MIND_API_OPERATOR_IMPL(ListGetItem, BaseOperator);
MIND_API_OPERATOR_IMPL(RealListGetItem, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(TupleGetItem, prim::kPrimTupleGetItem, SequenceGetItemInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealTupleGetItem, prim::kPrimRealTupleGetItem, SequenceGetItemInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ListGetItem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealListGetItem, prim::kPrimRealListGetItem, SequenceGetItemInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -65,7 +65,7 @@ def TupleGetItem(x, index):
return x[index]
def scalar_gt(x, y):
def ScalarGreater(x, y):
"""Implement `scalar_gt`."""
return x > y
@ -75,17 +75,17 @@ def scalar_ne(x, y):
return x != y
def scalar_eq(x, y):
def ScalarEqual(x, y):
"""Implement `scalar_eq`."""
return x == y
def scalar_le(x, y):
def ScalarLessEqual(x, y):
"""Implement `scalar_le`."""
return x <= y
def scalar_lt(x, y):
def ScalarLess(x, y):
"""Implement `scalar_lt`."""
return x < y

View File

@ -114,7 +114,7 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
if parameter_divisibility:
delta_w = P.Squeeze()(delta_w)
ori_len = F.shape(delta_w)[0]
divide_len = ori_len / 2
divide_len = ori_len // 2
left_part = delta_w[:divide_len]
right_part = delta_w[divide_len:]
else:

View File

@ -18,12 +18,7 @@
""" Define constants"""
# Arithmetic
kScalarAdd = "ScalarAdd"
kScalarSub = "ScalarSub"
kScalarMul = "ScalarMul"
kScalarDiv = "ScalarDiv"
kScalarFloordiv = "ScalarFloordiv"
kScalarMod = "ScalarMod"
kScalarPow = "ScalarPow"
kScalarTrunc = "ScalarTrunc"
kScalarFloor = "ScalarFloor"

View File

@ -20,6 +20,7 @@ from mindspore.ops import operations as P
from mindspore.ops.composite import multitype_ops as C
from mindspore.ops._grad.grad_base import bprops
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _scalar_ops
get_dtype = P.DType()
# Unused parameters are placeholders.
@ -35,25 +36,25 @@ def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
return F.zeros_like(x), F.zeros_like(y), dz
@bprops.register(_constants.kScalarAdd)
@bprops.register(_scalar_ops.ScalarAdd)
def bprop_scalar_add(x, y, out, dout):
"""Backpropagator for primitive `scalar_add`."""
return dout, dout
@bprops.register(_constants.kScalarMul)
@bprops.register(_scalar_ops.ScalarMul)
def bprop_scalar_mul(x, y, out, dout):
"""Backpropagator for primitive `scalar_mul`."""
return dout * y, dout * x
@bprops.register(_constants.kScalarSub)
@bprops.register(_scalar_ops.ScalarSub)
def bprop_scalar_sub(x, y, out, dout):
"""Backpropagator for primitive `scalar_sub`."""
return dout, -dout
@bprops.register(_constants.kScalarDiv)
@bprops.register(_scalar_ops.ScalarDiv)
def bprop_scalar_div(x, y, out, dout):
"""Backpropagator for primitive `scalar_div`."""
return dout / y, (-dout) * (out / y)
@ -187,16 +188,16 @@ def bprop_mutable(x, out, dout):
return (dout,)
@bprops.register("scalar_gt")
@bprops.register("scalar_lt")
@bprops.register("scalar_ge")
@bprops.register("scalar_le")
@bprops.register("scalar_eq")
@bprops.register(_scalar_ops.ScalarGreater)
@bprops.register(_scalar_ops.ScalarLess)
@bprops.register(_scalar_ops.ScalarGreaterEqual)
@bprops.register(_scalar_ops.ScalarLessEqual)
@bprops.register(_scalar_ops.ScalarEqual)
@bprops.register("scalar_ne")
@bprops.register("bool_and")
@bprops.register("bool_or")
@bprops.register("bit_and")
@bprops.register("bit_or")
@bprops.register(_scalar_ops.ScalarBitwiseAnd)
@bprops.register(_scalar_ops.ScalarBitwiseOr)
@bprops.register("bit_xor")
@bprops.register("bit_left_shift")
@bprops.register("bit_right_shift")

View File

@ -22,7 +22,7 @@ from mindspore.ops.function import *
from mindspore.ops.function.array_func import narrow
from mindspore.ops import operations as P
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops
from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops, _scalar_ops
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.array_ops import UniqueConsecutive, Triu
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
@ -55,6 +55,17 @@ partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
identity = P.identity()
# tuple/list/scalar ops
scalar_div = _scalar_ops.ScalarDiv()
scalar_mod = _scalar_ops.ScalarMod()
scalar_add = _scalar_ops.ScalarAdd()
scalar_mul = _scalar_ops.ScalarMul()
scalar_sub = _scalar_ops.ScalarSub()
scalar_gt = _scalar_ops.ScalarGreater()
scalar_ge = _scalar_ops.ScalarGreaterEqual()
scalar_le = _scalar_ops.ScalarLessEqual()
scalar_lt = _scalar_ops.ScalarLess()
scalar_eq = _scalar_ops.ScalarEqual()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive(_constants.kTupleGetItem)
@ -73,22 +84,12 @@ make_list = Primitive('make_list')
make_slice = Primitive('make_slice')
tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
scalar_add = Primitive(_constants.kScalarAdd)
scalar_mul = Primitive(_constants.kScalarMul)
scalar_sub = Primitive(_constants.kScalarSub)
scalar_div = Primitive(_constants.kScalarDiv)
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
scalar_log = Primitive('scalar_log')
scalar_pow = Primitive(_constants.kScalarPow)
scalar_gt = Primitive('scalar_gt')
scalar_ge = Primitive('scalar_ge')
scalar_le = Primitive('scalar_le')
scalar_lt = Primitive('scalar_lt')
scalar_eq = Primitive('scalar_eq')
scalar_ne = Primitive('scalar_ne')
scalar_uadd = Primitive(_constants.kScalarUadd)
scalar_usub = Primitive(_constants.kScalarUsub)
scalar_mod = Primitive(_constants.kScalarMod)
string_eq = Primitive('string_eq')
string_concat = Primitive('string_concat')
bool_not = Primitive("bool_not")

View File

@ -21,6 +21,7 @@ import numpy as np
from mindspore.common import Tensor
from mindspore.ops import composite as C
from mindspore.ops.operations.array_ops import Cast
from mindspore.ops.operations._scalar_ops import ScalarBitwiseOr, ScalarBitwiseAnd
from mindspore.ops import signature as sig
from mindspore.ops.operations.math_ops import _infer_shape_reduce
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
@ -35,8 +36,8 @@ from mindspore.common._register_for_adapter import ms_adapter_registry
# Bit operation
bit_and = Primitive("bit_and")
bit_or = Primitive("bit_or")
bit_and = ScalarBitwiseAnd()
bit_or = ScalarBitwiseOr()
bit_xor = Primitive("bit_xor")
bit_left_shift = Primitive("bit_left_shift")
bit_right_shift = Primitive("bit_right_shift")

View File

@ -0,0 +1,367 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operations for sequence"""
from mindspore.ops.primitive import Primitive, prim_attr_register
class ScalarDiv(Primitive):
r"""
Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
.. math::
out_{i} = \frac{x_i}{y_i}
.. note::
The inputs can be constant/variable value. Usage is the same as '/' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type of scalar is float.
Raises:
TypeError: If `x` and `y` are not scalar.
ValueError: If `y` is 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarDiv"""
class ScalarAdd(Primitive):
r"""
Adds two input scalar.
.. note::
The inputs can be constant/variable value. Usage is the same as '+' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarAdd"""
class ScalarSub(Primitive):
r"""
Subtracts the second input Scalar from the first input Scalar.
.. note::
The inputs can be constant/variable value. Usage is the same as '-' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarSub"""
class ScalarMul(Primitive):
r"""
Muls two input scalar.
.. note::
The inputs can be constant/variable value. Usage is the same as '+' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarMul"""
class ScalarEqual(Primitive):
r"""
Computes the equivalence between two Scalars.
.. note::
The inputs can be constant/variable value. Usage is the same as '==' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type of scalar is bool.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarMul"""
class ScalarGreater(Primitive):
r"""
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
.. note::
The inputs can be constant/variable value. Usage is the same as '>' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type of scalar is bool.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarGreater"""
class ScalarLess(Primitive):
r"""
Computes the boolean value of :math:`x < y`.
.. note::
The inputs can be constant/variable value. Usage is the same as '<' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type of scalar is bool.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarLess"""
class ScalarGreaterEqual(Primitive):
r"""
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
.. note::
The inputs can be constant/variable value. Usage is the same as '>=' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type of scalar is bool.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarGreaterEqual"""
class ScalarLessEqual(Primitive):
r"""
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
.. note::
The inputs can be constant/variable value. Usage is the same as '<=' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type of scalar is bool.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarLessEqual"""
class ScalarMod(Primitive):
r"""
Computes the remainder of dividing the first input scalar by the second input scalar element-wise.
.. math::
out_{i} = x_{i} \text{ % } y_{i}
.. note::
The inputs can be constant/variable value. Usage is the same as '%' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
- **y** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarMod"""
class ScalarBool(Primitive):
r"""
Computes the input scalar true or false.
.. note::
The inputs can be constant/variable value.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar.
Outputs:
Scalar, the type is bool.
Raises:
TypeError: If `x` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarBool"""
class ScalarBitwiseAnd(Primitive):
r"""
Returns bitwise `and` of two scalars.
.. math::
out_{i} = x_{i} \text{ % } y_{i}
.. note::
The inputs can be constant/variable value. Usage is the same as '%' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar, the type can be int or bool.
- **y** (Scalar) - A constant or variable scalar, the type can be int or bool.
Outputs:
Scalar, the type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarMod"""
class ScalarBitwiseOr(Primitive):
r"""
Returns bitwise `or` of two scalars.
.. note::
The inputs can be constant/variable value. Usage is the same as '|' in Python.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Scalar) - A constant or variable scalar, the type can be int or bool.
- **y** (Scalar) - A constant or variable scalar, the type can be int or bool.
Outputs:
Scalar, the type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If `x` and `y` are not scalar.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarMod"""

View File

@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -43,7 +43,6 @@ class ListAppend(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize ListAppend"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_data', 'target'], outputs=['output_data'])
@ -75,7 +74,6 @@ class SequenceSlice(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SequenceSlice"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['seq', 'start', 'stop', 'step'], outputs=['output_data'])
@ -107,7 +105,6 @@ class SequenceSliceSetItem(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SequenceSliceSetItem"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['seq', 'target', 'start', 'stop', 'step'], outputs=['output_data'])
@ -136,7 +133,6 @@ class SequenceAdd(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SequenceAdd"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_1', 'input_2'], outputs=['output_data'])
@ -168,7 +164,6 @@ class TupleToTensor(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize TupleToTensor"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_tuple', 'dtype'], outputs=['output_data'])
@ -200,7 +195,6 @@ class ListToTensor(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize ListToTensor"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_list', 'dtype'], outputs=['output_data'])
@ -228,7 +222,6 @@ class TensorToTuple(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize TensorToTuple"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
@ -256,7 +249,6 @@ class TensorToList(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize TensorToList"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
@ -284,7 +276,6 @@ class TensorToScalar(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize TensorToScalar"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
@ -312,7 +303,6 @@ class SequenceCount(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SequenceCount"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['sequence', 'target'], outputs=['output_data'])
@ -341,5 +331,4 @@ class SequenceMul(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SequenceMul"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['sequence', 'scalar'], outputs=['output_data'])

View File

@ -1730,9 +1730,9 @@ class ScalarToTensor(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
pass
self.init_prim_io_names(inputs=['input_scalar', 'dtype'], outputs=['output_data'])
def infer_value(self, x, dtype=mstype.float32):
def __call__(self, x, dtype=mstype.float32):
validator.check_value_type("x", x, [int, float], self.name)
validator.check_subclass("dtype", dtype, mstype.number, self.name)
data_type = mstype.dtype_to_nptype(dtype)

View File

@ -146,12 +146,12 @@ class YOLOv3(nn.Cell):
con1, big_object_output = self.backblock0(feature_map3)
con1 = self.conv1(con1)
ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con1)
ups1 = P.ResizeNearestNeighbor((img_hight // 16, img_width // 16))(con1)
con1 = self.concat((ups1, feature_map2))
con2, medium_object_output = self.backblock1(con1)
con2 = self.conv2(con2)
ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con2)
ups2 = P.ResizeNearestNeighbor((img_hight // 8, img_width // 8))(con2)
con3 = self.concat((ups2, feature_map1))
_, small_object_output = self.backblock2(con3)

View File

@ -112,17 +112,17 @@ TEST_F(TestOps, ScalarTanTest) {
// Comparisons
TEST_F(TestOps, ScalarEqTest) {
auto prim = std::make_shared<Primitive>("scalar_eq");
auto prim = std::make_shared<Primitive>(prim::kScalarEq);
ASSERT_EQ(prim->name(), kPrimScalarEq->name());
}
TEST_F(TestOps, ScalarLtTest) {
auto prim = std::make_shared<Primitive>("scalar_lt");
auto prim = std::make_shared<Primitive>(prim::kScalarLt);
ASSERT_EQ(prim->name(), kPrimScalarLt->name());
}
TEST_F(TestOps, ScalarGtTest) {
auto prim = std::make_shared<Primitive>("scalar_gt");
auto prim = std::make_shared<Primitive>(prim::kScalarGt);
ASSERT_EQ(prim->name(), kPrimScalarGt->name());
}
@ -132,12 +132,12 @@ TEST_F(TestOps, ScalarNeTest) {
}
TEST_F(TestOps, ScalarLeTest) {
auto prim = std::make_shared<Primitive>("scalar_le");
auto prim = std::make_shared<Primitive>(prim::kScalarLe);
ASSERT_EQ(prim->name(), kPrimScalarLe->name());
}
TEST_F(TestOps, ScalarGeTest) {
auto prim = std::make_shared<Primitive>("scalar_ge");
auto prim = std::make_shared<Primitive>(prim::kScalarGe);
ASSERT_EQ(prim->name(), kPrimScalarGe->name());
}

View File

@ -164,7 +164,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
ASSERT_TRUE(*abs_base_got->BuildValue() == *MakeValue(static_cast<int64_t>(3)));
}
class TestInferGraph : public UT::Common {
@ -273,7 +273,7 @@ TEST_F(TestInferGraph, test_inferred) {
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
ASSERT_TRUE(*abs_base_got->BuildValue() == *MakeValue(static_cast<int64_t>(3)));
}
TEST_F(TestInferGraph, test_context) {
@ -352,6 +352,7 @@ void TestInferMetaGraph::TearDown() {
TEST_F(TestInferMetaGraph, test_inferred) {
AbstractBasePtrList args_spec_list;
int64_t v1 = 1;
int64_t res = 2;
std::cout << "Begin TestInferGraph." << std::endl;
std::cout << func_graph_->get_return()->ToString() << std::endl;
AbstractBasePtr abstract_v1 = FromValue(v1, false);
@ -359,7 +360,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
ASSERT_TRUE(*abs_base_got->BuildValue() == *MakeValue(res));
}
class TestInferUniform : public UT::Common {

View File

@ -13,11 +13,10 @@
# limitations under the License.
# ============================================================================
""" Test for GraphCloner """
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import functional as F
scala_add = Primitive(Constants.kScalarAdd)
scalar_mul = Primitive(Constants.kScalarMul)
scala_add = F.scalar_add
scalar_mul = F.scalar_mul
def test_clone_simple():

View File

@ -17,11 +17,11 @@ import numpy as np
import mindspore as ms
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import functional as F
from tests.ut.python.model.resnet import resnet50
scala_add = Primitive(Constants.kScalarAdd)
scala_add = F.scalar_add
def scalar_add(x, y):

View File

@ -29,8 +29,8 @@ from mindspore.ops.operations import _grad_ops as G
# pylint: disable=unused-argument
# pylint: disable=redefined-outer-name
scalar_add = Primitive(Constants.kScalarAdd)
scalar_mul = Primitive(Constants.kScalarMul)
scalar_add = F.scalar_add
scalar_mul = F.scalar_mul
tuple_getitem = Primitive(Constants.kTupleGetItem)
switch = Primitive('Switch')
@ -354,7 +354,7 @@ def test_inline_while(tag):
def test_cse(tag):
""" test_cse """
fns = FnDict()
scalar_div = Primitive(Constants.kScalarDiv)
scalar_div = F.scalar_div
@fns
def test_f1(x, y):
@ -774,7 +774,7 @@ def test_incorporate_getitem(tag):
def test_incorporate_getitem_through_switch(tag):
""" test_incorporate_getitem_through_switch """
fns = FnDict()
scalar_gt = Primitive('scalar_gt')
scalar_gt = F.scalar_gt
@fns
def before(x, y):
@ -834,7 +834,7 @@ def test_incorporate_call_through_switch(tag):
fns = FnDict()
f1 = Primitive('f1')
f2 = Primitive('f2')
scalar_gt = Primitive('scalar_gt')
scalar_gt = F.scalar_gt
identity = Primitive('identity')
@fns
@ -869,7 +869,7 @@ def test_incorporate_call_through_switch(tag):
def test_float_tuple_getitem_through_switch(tag):
""" test_float_tuple_getitem_through_switch """
fns = FnDict()
scalar_gt = Primitive('scalar_gt')
scalar_gt = F.scalar_gt
@fns
def before(x, y):
@ -931,7 +931,7 @@ def test_convert_switch_ops(tag):
fns = FnDict()
ge_switch = Primitive('GeSwitch')
merge = Primitive('Merge')
add = Primitive(Constants.kScalarAdd)
add = F.scalar_add
neg = Primitive('Neg')
tuple_getitem = Primitive(Constants.kTupleGetItem)
make_tuple = Primitive('MakeTuple')

View File

@ -13,12 +13,10 @@
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import Primitive
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
scala_add = Primitive(Constants.kScalarAdd)
scala_add = F.scalar_add
class AddNet(nn.Cell):

View File

@ -13,8 +13,7 @@
# limitations under the License.
# ============================================================================
""" multi_relu_case """
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import functional as F
# Test user define ops
@ -22,7 +21,7 @@ def get_test_ops_fn():
return test_ops_f
scalar_mul = Primitive(Constants.kScalarMul)
scalar_mul = F.scalar_mul
def test_ops_f(x, y):

View File

@ -13,12 +13,11 @@
# limitations under the License.
# ============================================================================
""" vm_test """
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import functional as F
scala_add = Primitive(Constants.kScalarAdd)
scala_mul = Primitive(Constants.kScalarMul)
scalar_gt = Primitive('scalar_gt')
scala_add = F.scalar_add
scala_mul = F.scalar_mul
scalar_gt = F.scalar_gt
def ScalarAdd(x, y):

View File

@ -0,0 +1,96 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test mul operation for dynamic sequence and variable integer in graph mode"""
import numpy as np
from mindspore import jit
from mindspore import context
from mindspore.common import mutable
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE)
def test_constant_scalar_div_and_mod():
"""
Feature: Constant scalar div and mod operation.
Description:
Expectation: No exception.
"""
@jit
def foo():
return 2/3, -2%3
ret1, ret2 = foo()
tol = 1e-6
assert np.abs(ret1 - 0.666666) < tol
assert np.abs(ret2 - 1) < tol
def test_constant_scalar_bitwise():
"""
Feature: Constant scalar bitwise operation.
Description:
Expectation: No exception.
"""
@jit
def foo():
return 3 & 1, True & 4, 4 | 3
ret1, ret2, ret3 = foo()
assert ret1 == 1
assert ret2 == 0
assert ret3 == 7
def test_variable_scalar_div_and_mod():
"""
Feature: Variable scalar div and mod operation.
Description:
Expectation: No exception.
"""
@jit
def foo():
x = mutable(2)
y = mutable(3)
ret1 = x / y
ret2 = x % y
return isinstance(ret1, float), F.isconstant(ret2)
ret1, ret2 = foo()
assert ret1
assert not ret2
def test_variable_scalar_bitwise():
"""
Feature: Variable scalar bitwise operation.
Description:
Expectation: No exception.
"""
@jit
def foo():
x = mutable(2)
y = mutable(3)
ret1 = x & y
ret2 = x | y
return isinstance(ret1, int), F.isconstant(ret2)
ret1, ret2 = foo()
assert ret1
assert not ret2

View File

@ -142,9 +142,8 @@ def test_bitwise_operator_error_float_input():
return res
net = Net()
with pytest.raises(TypeError) as err:
with pytest.raises(TypeError):
net()
assert "Unsupported input type. For BitOr, only integer types are supported, but got" in str(err.value)
def test_bitwise_operator_error_too_large_number():

View File

@ -17,8 +17,6 @@ import numpy as np
from mindspore import Tensor
from mindspore.common.api import jit
from mindspore.ops import Primitive
from mindspore.ops import _constants
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
@ -29,7 +27,7 @@ from ...ut_filter import non_graph_engine
tensor_add = P.Add()
scala_add = Primitive(_constants.kScalarAdd)
scala_add = F.scalar_add
add = C.MultitypeFuncGraph('add')

View File

@ -18,16 +18,15 @@ import numpy as np
from mindspore import Tensor
from mindspore.common.api import jit
from mindspore.common.parameter import Parameter
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import _constants
from mindspore.ops import functional as F
from mindspore import dtype as mstype
from ...ut_filter import non_graph_engine
tensor_add = P.Add()
op_add = P.AddN()
scala_add = Primitive(_constants.kScalarAdd)
scala_add = F.scalar_add
add = C.MultitypeFuncGraph('add')