forked from mindspore-Ecosystem/mindspore
!46951 新增scalar ops原语
Merge pull request !46951 from huoxinyou/1218scalarop
This commit is contained in:
commit
3c6592e8b1
|
@ -30,13 +30,13 @@ PrimToFunction::PrimToFunction()
|
||||||
{kScalarUsub, kPrimTypeNumOneArg}, {kScalarAdd, kPrimTypeNumTwoArgs},
|
{kScalarUsub, kPrimTypeNumOneArg}, {kScalarAdd, kPrimTypeNumTwoArgs},
|
||||||
{"bool_and", kPrimTypeNumTwoArgs}, {"bool_eq", kPrimTypeNumTwoArgs},
|
{"bool_and", kPrimTypeNumTwoArgs}, {"bool_eq", kPrimTypeNumTwoArgs},
|
||||||
{"bool_or", kPrimTypeNumTwoArgs}, {kScalarDiv, kPrimTypeNumTwoArgs},
|
{"bool_or", kPrimTypeNumTwoArgs}, {kScalarDiv, kPrimTypeNumTwoArgs},
|
||||||
{"scalar_eq", kPrimTypeNumTwoArgs}, {"scalar_ge", kPrimTypeNumTwoArgs},
|
{kScalarEq, kPrimTypeNumTwoArgs}, {kScalarGe, kPrimTypeNumTwoArgs},
|
||||||
{"scalar_gt", kPrimTypeNumTwoArgs}, {"scalar_le", kPrimTypeNumTwoArgs},
|
{kScalarGt, kPrimTypeNumTwoArgs}, {kScalarLe, kPrimTypeNumTwoArgs},
|
||||||
{"scalar_lt", kPrimTypeNumTwoArgs}, {"scalar_ne", kPrimTypeNumTwoArgs},
|
{kScalarLt, kPrimTypeNumTwoArgs}, {"scalar_ne", kPrimTypeNumTwoArgs},
|
||||||
{kScalarMod, kPrimTypeNumTwoArgs}, {kScalarMul, kPrimTypeNumTwoArgs},
|
{kScalarMod, kPrimTypeNumTwoArgs}, {kScalarMul, kPrimTypeNumTwoArgs},
|
||||||
{kScalarPow, kPrimTypeNumTwoArgs}, {kScalarSub, kPrimTypeNumTwoArgs},
|
{kScalarPow, kPrimTypeNumTwoArgs}, {kScalarSub, kPrimTypeNumTwoArgs},
|
||||||
{kScalarFloordiv, kPrimTypeNumTwoArgs}, {"bit_and", kPrimTypeNumTwoArgs},
|
{kScalarFloordiv, kPrimTypeNumTwoArgs}, {kScalarBitwiseAnd, kPrimTypeNumTwoArgs},
|
||||||
{"bit_or", kPrimTypeNumTwoArgs}, {"bit_xor", kPrimTypeNumTwoArgs},
|
{kScalarBitwiseOr, kPrimTypeNumTwoArgs}, {"bit_xor", kPrimTypeNumTwoArgs},
|
||||||
{"bit_left_shift", kPrimTypeNumTwoArgs}, {"bit_right_shift", kPrimTypeNumTwoArgs},
|
{"bit_left_shift", kPrimTypeNumTwoArgs}, {"bit_right_shift", kPrimTypeNumTwoArgs},
|
||||||
{kStringNot, kPrimTypeStrOneArg}, {kStringConcat, kPrimTypeStrTwoArgs},
|
{kStringNot, kPrimTypeStrOneArg}, {kStringConcat, kPrimTypeStrTwoArgs},
|
||||||
{kStringIn, kPrimTypeStrTwoArgs}, {kStringEq, kPrimTypeStrTwoArgs},
|
{kStringIn, kPrimTypeStrTwoArgs}, {kStringEq, kPrimTypeStrTwoArgs},
|
||||||
|
|
|
@ -3063,27 +3063,15 @@ using PrimitiveToImplMap = mindspore::HashMap<PrimitivePtr, PrimitiveImplInferVa
|
||||||
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
|
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
|
||||||
using R = PrimitiveToImplMap::mapped_type;
|
using R = PrimitiveToImplMap::mapped_type;
|
||||||
static PrimitiveToImplMap uniform_prim_implement_map{
|
static PrimitiveToImplMap uniform_prim_implement_map{
|
||||||
{prim::kPrimScalarAdd, R{prim::ScalarAdd, true, nullptr, true}},
|
|
||||||
{prim::kPrimScalarSub, R{prim::ScalarSub, true, nullptr, true}},
|
|
||||||
{prim::kPrimScalarMul, R{prim::ScalarMul, true, nullptr, true}},
|
|
||||||
{prim::kPrimScalarDiv, R{prim::ScalarDiv, true, nullptr, true}},
|
|
||||||
{prim::kPrimScalarMod, R{prim::ScalarMod, true, nullptr, true}},
|
|
||||||
{prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}},
|
{prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}},
|
||||||
{prim::kPrimScalarFloordiv, R{prim::ScalarFloordiv, true, nullptr, true}},
|
{prim::kPrimScalarFloordiv, R{prim::ScalarFloordiv, true, nullptr, true}},
|
||||||
{prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}},
|
{prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}},
|
||||||
{prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}},
|
{prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}},
|
||||||
{prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}},
|
{prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}},
|
||||||
{prim::kPrimBitAnd, R{prim::BitAnd, true, nullptr, true}},
|
|
||||||
{prim::kPrimBitOr, R{prim::BitOr, true, nullptr, true}},
|
|
||||||
{prim::kPrimBitXor, R{prim::BitXor, true, nullptr, true}},
|
{prim::kPrimBitXor, R{prim::BitXor, true, nullptr, true}},
|
||||||
{prim::kPrimBitLeftShift, R{prim::BitLeftShift, true, nullptr, true}},
|
{prim::kPrimBitLeftShift, R{prim::BitLeftShift, true, nullptr, true}},
|
||||||
{prim::kPrimBitRightShift, R{prim::BitRightShift, true, nullptr, true}},
|
{prim::kPrimBitRightShift, R{prim::BitRightShift, true, nullptr, true}},
|
||||||
{prim::kPrimScalarEq, R{prim::ScalarEq, true, std::make_shared<Bool>(), true}},
|
|
||||||
{prim::kPrimScalarLt, R{prim::ScalarLt, true, std::make_shared<Bool>(), true}},
|
|
||||||
{prim::kPrimScalarGt, R{prim::ScalarGt, true, std::make_shared<Bool>(), true}},
|
|
||||||
{prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared<Bool>(), true}},
|
{prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared<Bool>(), true}},
|
||||||
{prim::kPrimScalarLe, R{prim::ScalarLe, true, std::make_shared<Bool>(), true}},
|
|
||||||
{prim::kPrimScalarGe, R{prim::ScalarGe, true, std::make_shared<Bool>(), true}},
|
|
||||||
{prim::kPrimBoolNot, R{prim::BoolNot, true, std::make_shared<Bool>(), true}},
|
{prim::kPrimBoolNot, R{prim::BoolNot, true, std::make_shared<Bool>(), true}},
|
||||||
{prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared<Bool>(), true}},
|
{prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared<Bool>(), true}},
|
||||||
{prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared<Bool>(), true}},
|
{prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared<Bool>(), true}},
|
||||||
|
|
|
@ -76,6 +76,14 @@ constexpr auto kScalarTrunc = "ScalarTrunc";
|
||||||
constexpr auto kScalarFloor = "ScalarFloor";
|
constexpr auto kScalarFloor = "ScalarFloor";
|
||||||
constexpr auto kScalarUadd = "ScalarUadd";
|
constexpr auto kScalarUadd = "ScalarUadd";
|
||||||
constexpr auto kScalarUsub = "ScalarUsub";
|
constexpr auto kScalarUsub = "ScalarUsub";
|
||||||
|
constexpr auto kScalarEq = "ScalarEqual";
|
||||||
|
constexpr auto kScalarLt = "ScalarLess";
|
||||||
|
constexpr auto kScalarGt = "ScalarGreater";
|
||||||
|
constexpr auto kScalarLe = "ScalarLessEqual";
|
||||||
|
constexpr auto kScalarGe = "ScalarGreaterEqual";
|
||||||
|
constexpr auto kScalarBool = "ScalarBool";
|
||||||
|
constexpr auto kScalarBitwiseAnd = "ScalarBitwiseAnd";
|
||||||
|
constexpr auto kScalarBitwiseOr = "ScalarBitwiseOr";
|
||||||
constexpr auto kExp = "Exp";
|
constexpr auto kExp = "Exp";
|
||||||
constexpr auto kEqual = "Equal";
|
constexpr auto kEqual = "Equal";
|
||||||
constexpr auto kNotEqual = "NotEqual";
|
constexpr auto kNotEqual = "NotEqual";
|
||||||
|
@ -505,12 +513,15 @@ GVAR_DEF(PrimitivePtr, kPrimStringMul, std::make_shared<Primitive>(kStringMul));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimStringGetItem, std::make_shared<Primitive>(kStringGetItem));
|
GVAR_DEF(PrimitivePtr, kPrimStringGetItem, std::make_shared<Primitive>(kStringGetItem));
|
||||||
|
|
||||||
// Comparisons
|
// Comparisons
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarEq, std::make_shared<Primitive>("scalar_eq"));
|
GVAR_DEF(PrimitivePtr, kPrimScalarEq, std::make_shared<Primitive>(kScalarEq));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarLt, std::make_shared<Primitive>("scalar_lt"));
|
GVAR_DEF(PrimitivePtr, kPrimScalarLt, std::make_shared<Primitive>(kScalarLt));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarGt, std::make_shared<Primitive>("scalar_gt"));
|
GVAR_DEF(PrimitivePtr, kPrimScalarGt, std::make_shared<Primitive>(kScalarGt));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarNe, std::make_shared<Primitive>("scalar_ne"));
|
GVAR_DEF(PrimitivePtr, kPrimScalarNe, std::make_shared<Primitive>("scalar_ne"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarLe, std::make_shared<Primitive>("scalar_le"));
|
GVAR_DEF(PrimitivePtr, kPrimScalarLe, std::make_shared<Primitive>(kScalarLe));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarGe, std::make_shared<Primitive>("scalar_ge"));
|
GVAR_DEF(PrimitivePtr, kPrimScalarGe, std::make_shared<Primitive>(kScalarGe));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimScalarBool, std::make_shared<Primitive>(kScalarBool));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimScalarBitwiseAnd, std::make_shared<Primitive>(kScalarBitwiseAnd));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimScalarBitwiseOr, std::make_shared<Primitive>(kScalarBitwiseOr));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimBoolNot, std::make_shared<Primitive>("bool_not"));
|
GVAR_DEF(PrimitivePtr, kPrimBoolNot, std::make_shared<Primitive>("bool_not"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimBoolAnd, std::make_shared<Primitive>("bool_and"));
|
GVAR_DEF(PrimitivePtr, kPrimBoolAnd, std::make_shared<Primitive>("bool_and"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimBoolOr, std::make_shared<Primitive>("bool_or"));
|
GVAR_DEF(PrimitivePtr, kPrimBoolOr, std::make_shared<Primitive>("bool_or"));
|
||||||
|
@ -1503,9 +1514,7 @@ GVAR_DEF(PrimitivePtr, kPrimTileShape, std::make_shared<Primitive>("tile_shape")
|
||||||
GVAR_DEF(PrimitivePtr, kPrimGenerateShapeIndex, std::make_shared<Primitive>("generate_shape_index"));
|
GVAR_DEF(PrimitivePtr, kPrimGenerateShapeIndex, std::make_shared<Primitive>("generate_shape_index"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimGenerateInverseIndex, std::make_shared<Primitive>("generate_inverse_index"));
|
GVAR_DEF(PrimitivePtr, kPrimGenerateInverseIndex, std::make_shared<Primitive>("generate_inverse_index"));
|
||||||
|
|
||||||
GVAR_DEF(PrimitivePtr, kPrimRealMakeList, std::make_shared<Primitive>(kRealMakeList));
|
|
||||||
GVAR_DEF(PrimitivePtr, kPrimRealTupleGetItem, std::make_shared<Primitive>(kRealTupleGetItem));
|
GVAR_DEF(PrimitivePtr, kPrimRealTupleGetItem, std::make_shared<Primitive>(kRealTupleGetItem));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimRealListGetItem, std::make_shared<Primitive>(kRealListGetItem));
|
|
||||||
GVAR_DEF(PrimitivePtr, kPrimListToTensor, std::make_shared<Primitive>(kListToTensor));
|
GVAR_DEF(PrimitivePtr, kPrimListToTensor, std::make_shared<Primitive>(kListToTensor));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimScalarToTensor, std::make_shared<Primitive>(kScalarToTensor));
|
GVAR_DEF(PrimitivePtr, kPrimScalarToTensor, std::make_shared<Primitive>(kScalarToTensor));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTensorToTuple, std::make_shared<Primitive>(kTensorToTuple));
|
GVAR_DEF(PrimitivePtr, kPrimTensorToTuple, std::make_shared<Primitive>(kTensorToTuple));
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "ops/real_makelist.h"
|
|
||||||
#include "abstract/ops/op_infer.h"
|
#include "abstract/ops/op_infer.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
|
@ -29,7 +28,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
MIND_API_OPERATOR_IMPL(MakeList, BaseOperator);
|
MIND_API_OPERATOR_IMPL(MakeList, BaseOperator);
|
||||||
MIND_API_OPERATOR_IMPL(RealMakeList, BaseOperator);
|
|
||||||
AbstractBasePtr MakeListInnerInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
AbstractBasePtr MakeListInnerInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return std::make_shared<abstract::AbstractList>(input_args);
|
return std::make_shared<abstract::AbstractList>(input_args);
|
||||||
}
|
}
|
||||||
|
@ -51,6 +49,5 @@ class MakeListInfer : public abstract::OpInferBase {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(MakeList, prim::kPrimMakeList, MakeListInfer, false);
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(MakeList, prim::kPrimMakeList, MakeListInfer, false);
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealMakeList, prim::kPrimRealMakeList, MakeListInfer, false);
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -682,6 +683,57 @@ template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractList>(const Pri
|
||||||
template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractTuple>(const PrimitivePtr &primitive,
|
template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractTuple>(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
const std::vector<AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T GetScalarValue(const std::string &op_name, const ValuePtr &elem) {
|
||||||
|
T res;
|
||||||
|
MS_EXCEPTION_IF_NULL(elem);
|
||||||
|
if (elem->isa<Int64Imm>()) {
|
||||||
|
auto elem_value = GetValue<int64_t>(elem);
|
||||||
|
res = static_cast<T>(elem_value);
|
||||||
|
} else if (elem->isa<Int32Imm>()) {
|
||||||
|
auto elem_value = GetValue<int32_t>(elem);
|
||||||
|
res = static_cast<T>(elem_value);
|
||||||
|
} else if (elem->isa<FP64Imm>()) {
|
||||||
|
auto elem_value = GetValue<double>(elem);
|
||||||
|
res = static_cast<T>(elem_value);
|
||||||
|
} else if (elem->isa<FP32Imm>()) {
|
||||||
|
auto elem_value = GetValue<float>(elem);
|
||||||
|
res = static_cast<T>(elem_value);
|
||||||
|
} else if (elem->isa<BoolImm>()) {
|
||||||
|
auto elem_value = GetValue<bool>(elem);
|
||||||
|
res = static_cast<T>(elem_value);
|
||||||
|
} else {
|
||||||
|
MS_EXCEPTION(TypeError) << "For op '" << op_name
|
||||||
|
<< "' input must be [int32, int64, float32, float64, bool], but got " << elem->ToString();
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template int64_t GetScalarValue(const std::string &op_name, const ValuePtr &elem);
|
||||||
|
template int32_t GetScalarValue(const std::string &op_name, const ValuePtr &elem);
|
||||||
|
template double GetScalarValue(const std::string &op_name, const ValuePtr &elem);
|
||||||
|
template float GetScalarValue(const std::string &op_name, const ValuePtr &elem);
|
||||||
|
template bool GetScalarValue(const std::string &op_name, const ValuePtr &elem);
|
||||||
|
|
||||||
|
TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name) {
|
||||||
|
static std::map<TypeId, size_t> prio_map = {{kNumberTypeFloat64, 1},
|
||||||
|
{kNumberTypeFloat32, 2},
|
||||||
|
{kNumberTypeInt64, 3},
|
||||||
|
{kNumberTypeInt32, 4},
|
||||||
|
{kNumberTypeBool, 5}};
|
||||||
|
auto x_iter = prio_map.find(x_type->type_id());
|
||||||
|
auto y_iter = prio_map.find(y_type->type_id());
|
||||||
|
if (x_iter == prio_map.end() || y_iter == prio_map.end()) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For '" << op_name
|
||||||
|
<< "', the x and y type should be int or float, but got x type: " << x_type
|
||||||
|
<< " y type: " << y_type;
|
||||||
|
}
|
||||||
|
if (x_iter->second < y_iter->second) {
|
||||||
|
return x_type;
|
||||||
|
}
|
||||||
|
return y_type;
|
||||||
|
}
|
||||||
|
|
||||||
bool IsValueKnown(const ValuePtr &value) {
|
bool IsValueKnown(const ValuePtr &value) {
|
||||||
// For now if the Abstract is a container of elements such as AbstractSequence and AbstractDictionary,
|
// For now if the Abstract is a container of elements such as AbstractSequence and AbstractDictionary,
|
||||||
// the BuildValue returns AnyValue if any one of the elements' value is AnyValue
|
// the BuildValue returns AnyValue if any one of the elements' value is AnyValue
|
||||||
|
|
|
@ -112,6 +112,11 @@ std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const Abstract
|
||||||
template <typename T>
|
template <typename T>
|
||||||
AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
|
AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T GetScalarValue(const std::string &op_name, const ValuePtr &elem);
|
||||||
|
|
||||||
|
TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name);
|
||||||
|
|
||||||
bool IsValueKnown(const ValuePtr &value);
|
bool IsValueKnown(const ValuePtr &value);
|
||||||
|
|
||||||
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,21 +14,23 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_REAL_LIST_GETITEM_H_
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_ADD_H_
|
||||||
#define MINDSPORE_CORE_OPS_REAL_LIST_GETITEM_H_
|
#define MINDSPORE_CORE_OPS_SCALAR_ADD_H_
|
||||||
#include "ops/base_operator.h"
|
#include "ops/base_operator.h"
|
||||||
#include "mindspore/core/ops/core_ops.h"
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
/// \brief RealListGetItem op is used to get list[index] value, list is a dynamic length list or index is variable
|
/// \brief ScalarAdd op is used to add between variable scalar.
|
||||||
class MIND_API RealListGetItem : public BaseOperator {
|
class MIND_API ScalarAdd : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
MIND_API_BASE_MEMBER(RealListGetItem);
|
MIND_API_BASE_MEMBER(ScalarAdd);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
RealListGetItem() : BaseOperator(prim::kRealListGetItem) { InitIOName({"input", "index"}, {"output"}); }
|
ScalarAdd() : BaseOperator(prim::kScalarAdd) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
};
|
};
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_REAL_LIST_GETITEM_H_
|
#endif // MINDSPORE_CORE_OPS_SCALAR_ADD_H_
|
|
@ -0,0 +1,325 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "abstract/ops/op_infer.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "include/common/utils/utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
#include "ops/scalar_add.h"
|
||||||
|
#include "ops/scalar_sub.h"
|
||||||
|
#include "ops/scalar_mul.h"
|
||||||
|
#include "ops/scalar_div.h"
|
||||||
|
#include "ops/scalar_mod.h"
|
||||||
|
#include "ops/scalar_eq.h"
|
||||||
|
#include "ops/scalar_lt.h"
|
||||||
|
#include "ops/scalar_gt.h"
|
||||||
|
#include "ops/scalar_le.h"
|
||||||
|
#include "ops/scalar_ge.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr AddImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||||
|
T res;
|
||||||
|
if (__builtin_add_overflow(x, y, &res)) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For prim '" << op_name
|
||||||
|
<< "' Overflow of the sum of two signed number x: " << std::to_string(x)
|
||||||
|
<< ", y: " << std::to_string(y) << ".";
|
||||||
|
}
|
||||||
|
return MakeValue(res);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return MakeValue(x + y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr SubImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||||
|
T res;
|
||||||
|
if (__builtin_sub_overflow(x, y, &res)) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For prim '" << op_name
|
||||||
|
<< "' Overflow of the sub of two signed number x: " << std::to_string(x)
|
||||||
|
<< ", y: " << std::to_string(y) << ".";
|
||||||
|
}
|
||||||
|
return MakeValue(res);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return MakeValue(x - y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr MulImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||||
|
T res;
|
||||||
|
if (__builtin_mul_overflow(x, y, &res)) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For prim '" << op_name
|
||||||
|
<< "' Overflow of the mul of two signed number x: " << std::to_string(x)
|
||||||
|
<< ", y: " << std::to_string(y) << ".";
|
||||||
|
}
|
||||||
|
return MakeValue(res);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return MakeValue(x * y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr DivImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
T zero = 0;
|
||||||
|
if (y == zero) {
|
||||||
|
MS_EXCEPTION(ValueError) << "The divisor could not be zero. But the divisor is zero now.";
|
||||||
|
}
|
||||||
|
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||||
|
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For prim '" << op_name
|
||||||
|
<< "' Overflow of the div of two signed number x: " << std::to_string(x)
|
||||||
|
<< ", y: " << std::to_string(y) << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MakeValue(static_cast<float>(x) / static_cast<float>(y));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr ModImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
T zero = 0;
|
||||||
|
if (y == zero) {
|
||||||
|
MS_EXCEPTION(ValueError) << "Cannot perform modulo operation on zero.";
|
||||||
|
}
|
||||||
|
if constexpr (std::is_signed<T>::value) {
|
||||||
|
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For prim '" << op_name
|
||||||
|
<< "' Overflow of the mod of two signed number x: " << std::to_string(x)
|
||||||
|
<< ", y: " << std::to_string(y) << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
T n = std::floor(static_cast<float>(x) / static_cast<float>(y));
|
||||||
|
T res = x - n * y;
|
||||||
|
return MakeValue(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr EqImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
if (std::isinf(static_cast<double>(x)) && std::isinf(static_cast<double>(y))) {
|
||||||
|
return MakeValue((x > 0 && y > 0) || (x < 0 && y < 0));
|
||||||
|
}
|
||||||
|
double error = static_cast<double>(x) - static_cast<double>(y);
|
||||||
|
error = fabs(error);
|
||||||
|
return MakeValue(error < DBL_EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr LtImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
return MakeValue(x < y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr GtImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
return MakeValue(x > y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr LeImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
return MakeValue(x <= y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ValuePtr GeImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
return MakeValue(x >= y);
|
||||||
|
}
|
||||||
|
|
||||||
|
using MathImplFunc = std::function<ValuePtr(const ValuePtr &, const ValuePtr &, const std::string &)>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
MathImplFunc ChooseFunc(const std::string &prim_name) {
|
||||||
|
std::map<std::string, MathImplFunc> infer_value_func_map = {
|
||||||
|
{prim::kScalarAdd, AddImpl<T>}, {prim::kScalarSub, SubImpl<T>}, {prim::kScalarMul, MulImpl<T>},
|
||||||
|
{prim::kScalarDiv, DivImpl<T>}, {prim::kScalarMod, ModImpl<T>}, {prim::kScalarEq, EqImpl<T>},
|
||||||
|
{prim::kScalarGt, GtImpl<T>}, {prim::kScalarLt, LtImpl<T>}, {prim::kScalarGe, GeImpl<T>},
|
||||||
|
{prim::kScalarLe, LeImpl<T>}};
|
||||||
|
auto iter = infer_value_func_map.find(prim_name);
|
||||||
|
if (iter == infer_value_func_map.end()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||||
|
<< "' don't support. Only support [Add, Sub, Mul, Div, Mod, Eq, Le, Ge, Lt, Gt]";
|
||||||
|
}
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
class ScalarArithmeticInfer : public abstract::OpInferBase {
|
||||||
|
public:
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
constexpr size_t input_len = 2;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, op_name);
|
||||||
|
auto elem_x = input_args[0];
|
||||||
|
auto elem_y = input_args[kIndex1];
|
||||||
|
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
|
||||||
|
<< " and y: " << elem_y->ToString();
|
||||||
|
}
|
||||||
|
return abstract::kNoShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
auto x_type = input_args[0]->BuildType();
|
||||||
|
auto y_type = input_args[kIndex1]->BuildType();
|
||||||
|
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64};
|
||||||
|
std::set<std::string> compare_ops = {prim::kScalarEq, prim::kScalarGe, prim::kScalarGt, prim::kScalarLt,
|
||||||
|
prim::kScalarLe};
|
||||||
|
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckSubClass("y_dtype", y_type, check_types, prim_name);
|
||||||
|
auto iter = compare_ops.find(prim_name);
|
||||||
|
if (prim_name == prim::kScalarDiv) {
|
||||||
|
return kFloat32;
|
||||||
|
}
|
||||||
|
if (iter != compare_ops.end()) {
|
||||||
|
return kBool;
|
||||||
|
}
|
||||||
|
return HighPriorityType(x_type, y_type, prim_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
constexpr size_t input_num = 2;
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
constexpr size_t x_index = 0;
|
||||||
|
constexpr size_t y_index = 1;
|
||||||
|
auto elem_x = input_args[x_index];
|
||||||
|
auto elem_y = input_args[y_index];
|
||||||
|
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
|
||||||
|
<< " and y: " << elem_y->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_value = elem_x->BuildValue();
|
||||||
|
auto y_value = elem_y->BuildValue();
|
||||||
|
if (x_value == kAnyValue || y_value == kAnyValue) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto x_type = input_args[x_index]->BuildType();
|
||||||
|
auto y_type = input_args[y_index]->BuildType();
|
||||||
|
auto res_type = HighPriorityType(x_type, y_type, op_name);
|
||||||
|
ValuePtr result;
|
||||||
|
switch (res_type->type_id()) {
|
||||||
|
case kNumberTypeInt32: {
|
||||||
|
auto func = ChooseFunc<int32_t>(op_name);
|
||||||
|
result = func(x_value, y_value, op_name);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeInt64: {
|
||||||
|
auto func = ChooseFunc<int64_t>(op_name);
|
||||||
|
result = func(x_value, y_value, op_name);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeFloat32: {
|
||||||
|
auto func = ChooseFunc<float>(op_name);
|
||||||
|
result = func(x_value, y_value, op_name);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeFloat64: {
|
||||||
|
auto func = ChooseFunc<double>(op_name);
|
||||||
|
result = func(x_value, y_value, op_name);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name
|
||||||
|
<< "', the supported type is in the list: [int32, int64, float32, float64], but got "
|
||||||
|
<< res_type->ToString() << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::function<ValuePtr(const ValuePtr &, const ValuePtr &, const std::string &)> infer_value_func_;
|
||||||
|
};
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarAdd, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarSub, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarMul, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarDiv, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarMod, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarEqual, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarGreater, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarGreaterEqual, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarLess, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarLessEqual, BaseOperator);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarAdd, prim::kPrimScalarAdd, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarSub, prim::kPrimScalarSub, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMul, prim::kPrimScalarMul, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarDiv, prim::kPrimScalarDiv, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMod, prim::kPrimScalarMod, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarEqual, prim::kPrimScalarEq, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreater, prim::kPrimScalarGt, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreaterEqual, prim::kPrimScalarGe, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLess, prim::kPrimScalarLt, ScalarArithmeticInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLessEqual, prim::kPrimScalarLe, ScalarArithmeticInfer, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,123 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "abstract/ops/op_infer.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "include/common/utils/utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
#include "ops/scalar_bitwise_or.h"
|
||||||
|
#include "ops/scalar_bitwise_and.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
template <typename T>
|
||||||
|
T BitwiseImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||||
|
MS_EXCEPTION_IF_NULL(x_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_value);
|
||||||
|
auto x = GetScalarValue<T>(op_name, x_value);
|
||||||
|
auto y = GetScalarValue<T>(op_name, y_value);
|
||||||
|
if (op_name == prim::kScalarBitwiseAnd) {
|
||||||
|
return x & y;
|
||||||
|
} else {
|
||||||
|
return x | y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ScalarBitwiseInfer : public abstract::OpInferBase {
|
||||||
|
public:
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
constexpr size_t input_len = 2;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, op_name);
|
||||||
|
auto elem_x = input_args[0];
|
||||||
|
auto elem_y = input_args[kIndex1];
|
||||||
|
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
|
||||||
|
<< " and y: " << elem_y->ToString();
|
||||||
|
}
|
||||||
|
return abstract::kNoShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
auto x_type = input_args[0]->BuildType();
|
||||||
|
auto y_type = input_args[kIndex1]->BuildType();
|
||||||
|
std::set<TypePtr> check_types = {kInt32, kInt64, kBool};
|
||||||
|
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckSubClass("y_dtype", y_type, check_types, prim_name);
|
||||||
|
return HighPriorityType(x_type, y_type, prim_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
constexpr size_t input_num = 2;
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
constexpr size_t x_index = 0;
|
||||||
|
constexpr size_t y_index = 1;
|
||||||
|
auto elem_x = input_args[x_index];
|
||||||
|
auto elem_y = input_args[y_index];
|
||||||
|
if (!elem_x->isa<abstract::AbstractScalar>() && !elem_y->isa<abstract::AbstractScalar>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got x: " << elem_x->ToString()
|
||||||
|
<< " and y: " << elem_y->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_value = elem_x->BuildValue();
|
||||||
|
auto y_value = elem_y->BuildValue();
|
||||||
|
if (x_value == kAnyValue || y_value == kAnyValue) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto res_type = InferType(primitive, input_args);
|
||||||
|
ValuePtr res;
|
||||||
|
switch (res_type->type_id()) {
|
||||||
|
case kNumberTypeInt32: {
|
||||||
|
res = MakeValue(BitwiseImpl<int32_t>(x_value, y_value, op_name));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeInt64: {
|
||||||
|
res = MakeValue(BitwiseImpl<int64_t>(x_value, y_value, op_name));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeBool: {
|
||||||
|
res = MakeValue(BitwiseImpl<bool>(x_value, y_value, op_name));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name
|
||||||
|
<< "', the supported type is in the list: [int32, int64, bool], but got "
|
||||||
|
<< res_type->ToString() << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarBitwiseOr, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarBitwiseAnd, BaseOperator);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseOr, prim::kPrimScalarBitwiseOr, ScalarBitwiseInfer, true);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseAnd, prim::kPrimScalarBitwiseAnd, ScalarBitwiseInfer, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_BITWISE_AND_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_BITWISE_AND_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief
|
||||||
|
class MIND_API ScalarBitwiseAnd : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarBitwiseAnd);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarBitwiseAnd() : BaseOperator(prim::kScalarBitwiseAnd) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_BITWISE_AND_H_
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,23 +14,23 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_REAL_MAKELIST_H_
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_BITWISE_OR_H_
|
||||||
#define MINDSPORE_CORE_OPS_REAL_MAKELIST_H_
|
#define MINDSPORE_CORE_OPS_SCALAR_BITWISE_OR_H_
|
||||||
#include "ops/base_operator.h"
|
#include "ops/base_operator.h"
|
||||||
#include "mindspore/core/ops/core_ops.h"
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
/// \brief RealMakeList op
|
/// \brief
|
||||||
class MIND_API RealMakeList : public BaseOperator {
|
class MIND_API ScalarBitwiseOr : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
MIND_API_BASE_MEMBER(RealMakeList);
|
MIND_API_BASE_MEMBER(ScalarBitwiseOr);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
RealMakeList() : BaseOperator(prim::kRealMakeList) {}
|
ScalarBitwiseOr() : BaseOperator(prim::kScalarBitwiseOr) {}
|
||||||
/// \brief Init.
|
/// \brief Init.
|
||||||
void Init() const {}
|
void Init() const {}
|
||||||
};
|
};
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_REAL_MAKELIST_H_
|
#endif // MINDSPORE_CORE_OPS_SCALAR_BITWISE_OR_H_
|
|
@ -0,0 +1,110 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include "ops/scalar_bool.h"
|
||||||
|
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "abstract/ops/op_infer.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "include/common/utils/utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
class ScalarBoolInfer : public abstract::OpInferBase {
|
||||||
|
public:
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
constexpr size_t input_len = 1;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, op_name);
|
||||||
|
auto elem = input_args[0];
|
||||||
|
if (!elem->isa<abstract::AbstractScalar>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got : " << elem->ToString();
|
||||||
|
}
|
||||||
|
return abstract::kNoShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
auto x_type = input_args[0]->BuildType();
|
||||||
|
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64, kBool};
|
||||||
|
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
|
||||||
|
return kBool;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
constexpr size_t input_num = 1;
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
auto elem = input_args[0];
|
||||||
|
if (!elem->isa<abstract::AbstractScalar>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input should be scalar but got : " << elem->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_valueptr = elem->BuildValue();
|
||||||
|
if (x_valueptr == kAnyValue) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto x_type = input_args[0]->BuildType();
|
||||||
|
bool res;
|
||||||
|
switch (x_type->type_id()) {
|
||||||
|
case kNumberTypeInt32: {
|
||||||
|
auto elem_value = GetValue<int32_t>(x_valueptr);
|
||||||
|
res = static_cast<bool>(elem_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeInt64: {
|
||||||
|
auto elem_value = GetValue<int64_t>(x_valueptr);
|
||||||
|
res = static_cast<bool>(elem_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeFloat32: {
|
||||||
|
auto elem_value = GetValue<float>(x_valueptr);
|
||||||
|
res = static_cast<bool>(elem_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeFloat64: {
|
||||||
|
auto elem_value = GetValue<double>(x_valueptr);
|
||||||
|
res = static_cast<bool>(elem_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNumberTypeBool: {
|
||||||
|
auto elem_value = GetValue<bool>(x_valueptr);
|
||||||
|
res = static_cast<bool>(elem_value);
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
MS_EXCEPTION(TypeError)
|
||||||
|
<< "For '" << op_name
|
||||||
|
<< "', the supported type is in the list: [int32, int64, float32, float64, bool], but got "
|
||||||
|
<< x_type->ToString() << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MakeValue(res);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MIND_API_OPERATOR_IMPL(ScalarBool, BaseOperator);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBool, prim::kPrimScalarBool, ScalarBoolInfer, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_BOOL_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_BOOL_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarBool op is used to calculate the input true or false.
|
||||||
|
class MIND_API ScalarBool : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarBool);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarBool() : BaseOperator(prim::kScalarBool) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_BOOL_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_DIV_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_DIV_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarDiv op is used to div between variable scalar.
|
||||||
|
class MIND_API ScalarDiv : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarDiv);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarDiv() : BaseOperator(prim::kScalarDiv) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_DIV_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_EQ_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_EQ_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarEqual op is used to judge equal between variable scalar.
|
||||||
|
class MIND_API ScalarEqual : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarEqual);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarEqual() : BaseOperator(prim::kScalarEq) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_EQ_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_GE_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_GE_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarGreaterEqual op is used to judge greaterEqual between variable scalar.
|
||||||
|
class MIND_API ScalarGreaterEqual : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarGreaterEqual);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarGreaterEqual() : BaseOperator(prim::kScalarGe) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_GE_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_GT_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_GT_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarGreater op is used to judge greater between variable scalar.
|
||||||
|
class MIND_API ScalarGreater : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarGreater);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarGreater() : BaseOperator(prim::kScalarGt) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_GT_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_LE_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_LE_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarLessEqual op is used to judge lessEqual between variable scalar.
|
||||||
|
class MIND_API ScalarLessEqual : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarLessEqual);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarLessEqual() : BaseOperator(prim::kScalarLe) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_LE_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_LT_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_LT_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarLess op is used to judge less between variable scalar.
|
||||||
|
class MIND_API ScalarLess : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarLess);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarLess() : BaseOperator(prim::kScalarLt) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_LT_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_MOD_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_MOD_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarMod op is used to mod between variable scalar.
|
||||||
|
class MIND_API ScalarMod : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarMod);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarMod() : BaseOperator(prim::kScalarMod) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_MOD_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_MUL_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_MUL_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarMul op is used to Mul between variable scalar.
|
||||||
|
class MIND_API ScalarMul : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarMul);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarMul() : BaseOperator(prim::kScalarMul) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_Mul_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCALAR_SUB_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCALAR_SUB_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief ScalarSub op is used to add between variable scalar.
|
||||||
|
class MIND_API ScalarSub : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScalarSub);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScalarSub() : BaseOperator(prim::kScalarSub) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCALAR_SUB_H_
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -88,10 +88,8 @@ class ScalarToTensorInfer : public abstract::OpInferBase {
|
||||||
MS_EXCEPTION_IF_NULL(attr);
|
MS_EXCEPTION_IF_NULL(attr);
|
||||||
}
|
}
|
||||||
if (!attr->isa<Type>()) {
|
if (!attr->isa<Type>()) {
|
||||||
MS_EXCEPTION(TypeError)
|
MS_EXCEPTION(TypeError) << "For '" << prim_name << "the second input must be a `Type`, but got "
|
||||||
<< "For '" << prim_name
|
<< attr->type_name();
|
||||||
<< "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
|
|
||||||
"'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
|
|
||||||
}
|
}
|
||||||
auto output_dtype = attr->cast<TypePtr>();
|
auto output_dtype = attr->cast<TypePtr>();
|
||||||
|
|
||||||
|
@ -128,5 +126,6 @@ class ScalarToTensorInfer : public abstract::OpInferBase {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
MIND_API_OPERATOR_IMPL(ScalarToTensor, BaseOperator);
|
MIND_API_OPERATOR_IMPL(ScalarToTensor, BaseOperator);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarToTensor, prim::kPrimScalarToTensor, ScalarToTensorInfer, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -20,7 +20,6 @@
|
||||||
#include "ops/tuple_get_item.h"
|
#include "ops/tuple_get_item.h"
|
||||||
#include "ops/list_getitem.h"
|
#include "ops/list_getitem.h"
|
||||||
#include "ops/real_tuple_getitem.h"
|
#include "ops/real_tuple_getitem.h"
|
||||||
#include "ops/real_list_getitem.h"
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "abstract/param_validator.h"
|
#include "abstract/param_validator.h"
|
||||||
#include "abstract/ops/op_infer.h"
|
#include "abstract/ops/op_infer.h"
|
||||||
|
@ -104,10 +103,8 @@ class SequenceGetItemInfer : public abstract::OpInferBase {
|
||||||
MIND_API_OPERATOR_IMPL(TupleGetItem, BaseOperator);
|
MIND_API_OPERATOR_IMPL(TupleGetItem, BaseOperator);
|
||||||
MIND_API_OPERATOR_IMPL(RealTupleGetItem, BaseOperator);
|
MIND_API_OPERATOR_IMPL(RealTupleGetItem, BaseOperator);
|
||||||
MIND_API_OPERATOR_IMPL(ListGetItem, BaseOperator);
|
MIND_API_OPERATOR_IMPL(ListGetItem, BaseOperator);
|
||||||
MIND_API_OPERATOR_IMPL(RealListGetItem, BaseOperator);
|
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(TupleGetItem, prim::kPrimTupleGetItem, SequenceGetItemInfer, false);
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(TupleGetItem, prim::kPrimTupleGetItem, SequenceGetItemInfer, false);
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealTupleGetItem, prim::kPrimRealTupleGetItem, SequenceGetItemInfer, false);
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealTupleGetItem, prim::kPrimRealTupleGetItem, SequenceGetItemInfer, false);
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ListGetItem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(ListGetItem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealListGetItem, prim::kPrimRealListGetItem, SequenceGetItemInfer, false);
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -65,7 +65,7 @@ def TupleGetItem(x, index):
|
||||||
return x[index]
|
return x[index]
|
||||||
|
|
||||||
|
|
||||||
def scalar_gt(x, y):
|
def ScalarGreater(x, y):
|
||||||
"""Implement `scalar_gt`."""
|
"""Implement `scalar_gt`."""
|
||||||
return x > y
|
return x > y
|
||||||
|
|
||||||
|
@ -75,17 +75,17 @@ def scalar_ne(x, y):
|
||||||
return x != y
|
return x != y
|
||||||
|
|
||||||
|
|
||||||
def scalar_eq(x, y):
|
def ScalarEqual(x, y):
|
||||||
"""Implement `scalar_eq`."""
|
"""Implement `scalar_eq`."""
|
||||||
return x == y
|
return x == y
|
||||||
|
|
||||||
|
|
||||||
def scalar_le(x, y):
|
def ScalarLessEqual(x, y):
|
||||||
"""Implement `scalar_le`."""
|
"""Implement `scalar_le`."""
|
||||||
return x <= y
|
return x <= y
|
||||||
|
|
||||||
|
|
||||||
def scalar_lt(x, y):
|
def ScalarLess(x, y):
|
||||||
"""Implement `scalar_lt`."""
|
"""Implement `scalar_lt`."""
|
||||||
return x < y
|
return x < y
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,7 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
|
||||||
if parameter_divisibility:
|
if parameter_divisibility:
|
||||||
delta_w = P.Squeeze()(delta_w)
|
delta_w = P.Squeeze()(delta_w)
|
||||||
ori_len = F.shape(delta_w)[0]
|
ori_len = F.shape(delta_w)[0]
|
||||||
divide_len = ori_len / 2
|
divide_len = ori_len // 2
|
||||||
left_part = delta_w[:divide_len]
|
left_part = delta_w[:divide_len]
|
||||||
right_part = delta_w[divide_len:]
|
right_part = delta_w[divide_len:]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -18,12 +18,7 @@
|
||||||
""" Define constants"""
|
""" Define constants"""
|
||||||
|
|
||||||
# Arithmetic
|
# Arithmetic
|
||||||
kScalarAdd = "ScalarAdd"
|
|
||||||
kScalarSub = "ScalarSub"
|
|
||||||
kScalarMul = "ScalarMul"
|
|
||||||
kScalarDiv = "ScalarDiv"
|
|
||||||
kScalarFloordiv = "ScalarFloordiv"
|
kScalarFloordiv = "ScalarFloordiv"
|
||||||
kScalarMod = "ScalarMod"
|
|
||||||
kScalarPow = "ScalarPow"
|
kScalarPow = "ScalarPow"
|
||||||
kScalarTrunc = "ScalarTrunc"
|
kScalarTrunc = "ScalarTrunc"
|
||||||
kScalarFloor = "ScalarFloor"
|
kScalarFloor = "ScalarFloor"
|
||||||
|
|
|
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.ops.composite import multitype_ops as C
|
from mindspore.ops.composite import multitype_ops as C
|
||||||
from mindspore.ops._grad.grad_base import bprops
|
from mindspore.ops._grad.grad_base import bprops
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops.operations import _scalar_ops
|
||||||
|
|
||||||
get_dtype = P.DType()
|
get_dtype = P.DType()
|
||||||
# Unused parameters are placeholders.
|
# Unused parameters are placeholders.
|
||||||
|
@ -35,25 +36,25 @@ def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
|
||||||
return F.zeros_like(x), F.zeros_like(y), dz
|
return F.zeros_like(x), F.zeros_like(y), dz
|
||||||
|
|
||||||
|
|
||||||
@bprops.register(_constants.kScalarAdd)
|
@bprops.register(_scalar_ops.ScalarAdd)
|
||||||
def bprop_scalar_add(x, y, out, dout):
|
def bprop_scalar_add(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_add`."""
|
"""Backpropagator for primitive `scalar_add`."""
|
||||||
return dout, dout
|
return dout, dout
|
||||||
|
|
||||||
|
|
||||||
@bprops.register(_constants.kScalarMul)
|
@bprops.register(_scalar_ops.ScalarMul)
|
||||||
def bprop_scalar_mul(x, y, out, dout):
|
def bprop_scalar_mul(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_mul`."""
|
"""Backpropagator for primitive `scalar_mul`."""
|
||||||
return dout * y, dout * x
|
return dout * y, dout * x
|
||||||
|
|
||||||
|
|
||||||
@bprops.register(_constants.kScalarSub)
|
@bprops.register(_scalar_ops.ScalarSub)
|
||||||
def bprop_scalar_sub(x, y, out, dout):
|
def bprop_scalar_sub(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_sub`."""
|
"""Backpropagator for primitive `scalar_sub`."""
|
||||||
return dout, -dout
|
return dout, -dout
|
||||||
|
|
||||||
|
|
||||||
@bprops.register(_constants.kScalarDiv)
|
@bprops.register(_scalar_ops.ScalarDiv)
|
||||||
def bprop_scalar_div(x, y, out, dout):
|
def bprop_scalar_div(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_div`."""
|
"""Backpropagator for primitive `scalar_div`."""
|
||||||
return dout / y, (-dout) * (out / y)
|
return dout / y, (-dout) * (out / y)
|
||||||
|
@ -187,16 +188,16 @@ def bprop_mutable(x, out, dout):
|
||||||
return (dout,)
|
return (dout,)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_gt")
|
@bprops.register(_scalar_ops.ScalarGreater)
|
||||||
@bprops.register("scalar_lt")
|
@bprops.register(_scalar_ops.ScalarLess)
|
||||||
@bprops.register("scalar_ge")
|
@bprops.register(_scalar_ops.ScalarGreaterEqual)
|
||||||
@bprops.register("scalar_le")
|
@bprops.register(_scalar_ops.ScalarLessEqual)
|
||||||
@bprops.register("scalar_eq")
|
@bprops.register(_scalar_ops.ScalarEqual)
|
||||||
@bprops.register("scalar_ne")
|
@bprops.register("scalar_ne")
|
||||||
@bprops.register("bool_and")
|
@bprops.register("bool_and")
|
||||||
@bprops.register("bool_or")
|
@bprops.register("bool_or")
|
||||||
@bprops.register("bit_and")
|
@bprops.register(_scalar_ops.ScalarBitwiseAnd)
|
||||||
@bprops.register("bit_or")
|
@bprops.register(_scalar_ops.ScalarBitwiseOr)
|
||||||
@bprops.register("bit_xor")
|
@bprops.register("bit_xor")
|
||||||
@bprops.register("bit_left_shift")
|
@bprops.register("bit_left_shift")
|
||||||
@bprops.register("bit_right_shift")
|
@bprops.register("bit_right_shift")
|
||||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops.function import *
|
||||||
from mindspore.ops.function.array_func import narrow
|
from mindspore.ops.function.array_func import narrow
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.primitive import Primitive
|
from mindspore.ops.primitive import Primitive
|
||||||
from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops
|
from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops, _scalar_ops
|
||||||
from mindspore.ops.operations.math_ops import Median
|
from mindspore.ops.operations.math_ops import Median
|
||||||
from mindspore.ops.operations.array_ops import UniqueConsecutive, Triu
|
from mindspore.ops.operations.array_ops import UniqueConsecutive, Triu
|
||||||
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
|
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
|
||||||
|
@ -55,6 +55,17 @@ partial = P.Partial()
|
||||||
# depend: mount a node to another node
|
# depend: mount a node to another node
|
||||||
depend = P.Depend()
|
depend = P.Depend()
|
||||||
identity = P.identity()
|
identity = P.identity()
|
||||||
|
# tuple/list/scalar ops
|
||||||
|
scalar_div = _scalar_ops.ScalarDiv()
|
||||||
|
scalar_mod = _scalar_ops.ScalarMod()
|
||||||
|
scalar_add = _scalar_ops.ScalarAdd()
|
||||||
|
scalar_mul = _scalar_ops.ScalarMul()
|
||||||
|
scalar_sub = _scalar_ops.ScalarSub()
|
||||||
|
scalar_gt = _scalar_ops.ScalarGreater()
|
||||||
|
scalar_ge = _scalar_ops.ScalarGreaterEqual()
|
||||||
|
scalar_le = _scalar_ops.ScalarLessEqual()
|
||||||
|
scalar_lt = _scalar_ops.ScalarLess()
|
||||||
|
scalar_eq = _scalar_ops.ScalarEqual()
|
||||||
|
|
||||||
tuple_setitem = Primitive('tuple_setitem')
|
tuple_setitem = Primitive('tuple_setitem')
|
||||||
tuple_getitem = Primitive(_constants.kTupleGetItem)
|
tuple_getitem = Primitive(_constants.kTupleGetItem)
|
||||||
|
@ -73,22 +84,12 @@ make_list = Primitive('make_list')
|
||||||
make_slice = Primitive('make_slice')
|
make_slice = Primitive('make_slice')
|
||||||
tuple_equal = Primitive("tuple_equal")
|
tuple_equal = Primitive("tuple_equal")
|
||||||
list_equal = Primitive("list_equal")
|
list_equal = Primitive("list_equal")
|
||||||
scalar_add = Primitive(_constants.kScalarAdd)
|
|
||||||
scalar_mul = Primitive(_constants.kScalarMul)
|
|
||||||
scalar_sub = Primitive(_constants.kScalarSub)
|
|
||||||
scalar_div = Primitive(_constants.kScalarDiv)
|
|
||||||
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
|
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
|
||||||
scalar_log = Primitive('scalar_log')
|
scalar_log = Primitive('scalar_log')
|
||||||
scalar_pow = Primitive(_constants.kScalarPow)
|
scalar_pow = Primitive(_constants.kScalarPow)
|
||||||
scalar_gt = Primitive('scalar_gt')
|
|
||||||
scalar_ge = Primitive('scalar_ge')
|
|
||||||
scalar_le = Primitive('scalar_le')
|
|
||||||
scalar_lt = Primitive('scalar_lt')
|
|
||||||
scalar_eq = Primitive('scalar_eq')
|
|
||||||
scalar_ne = Primitive('scalar_ne')
|
scalar_ne = Primitive('scalar_ne')
|
||||||
scalar_uadd = Primitive(_constants.kScalarUadd)
|
scalar_uadd = Primitive(_constants.kScalarUadd)
|
||||||
scalar_usub = Primitive(_constants.kScalarUsub)
|
scalar_usub = Primitive(_constants.kScalarUsub)
|
||||||
scalar_mod = Primitive(_constants.kScalarMod)
|
|
||||||
string_eq = Primitive('string_eq')
|
string_eq = Primitive('string_eq')
|
||||||
string_concat = Primitive('string_concat')
|
string_concat = Primitive('string_concat')
|
||||||
bool_not = Primitive("bool_not")
|
bool_not = Primitive("bool_not")
|
||||||
|
|
|
@ -21,6 +21,7 @@ import numpy as np
|
||||||
from mindspore.common import Tensor
|
from mindspore.common import Tensor
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops.operations.array_ops import Cast
|
from mindspore.ops.operations.array_ops import Cast
|
||||||
|
from mindspore.ops.operations._scalar_ops import ScalarBitwiseOr, ScalarBitwiseAnd
|
||||||
from mindspore.ops import signature as sig
|
from mindspore.ops import signature as sig
|
||||||
from mindspore.ops.operations.math_ops import _infer_shape_reduce
|
from mindspore.ops.operations.math_ops import _infer_shape_reduce
|
||||||
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
|
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
|
||||||
|
@ -35,8 +36,8 @@ from mindspore.common._register_for_adapter import ms_adapter_registry
|
||||||
|
|
||||||
|
|
||||||
# Bit operation
|
# Bit operation
|
||||||
bit_and = Primitive("bit_and")
|
bit_and = ScalarBitwiseAnd()
|
||||||
bit_or = Primitive("bit_or")
|
bit_or = ScalarBitwiseOr()
|
||||||
bit_xor = Primitive("bit_xor")
|
bit_xor = Primitive("bit_xor")
|
||||||
bit_left_shift = Primitive("bit_left_shift")
|
bit_left_shift = Primitive("bit_left_shift")
|
||||||
bit_right_shift = Primitive("bit_right_shift")
|
bit_right_shift = Primitive("bit_right_shift")
|
||||||
|
|
|
@ -0,0 +1,367 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Operations for sequence"""
|
||||||
|
|
||||||
|
from mindspore.ops.primitive import Primitive, prim_attr_register
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarDiv(Primitive):
|
||||||
|
r"""
|
||||||
|
Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
out_{i} = \frac{x_i}{y_i}
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '/' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type of scalar is float.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
ValueError: If `y` is 0.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarDiv"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarAdd(Primitive):
|
||||||
|
r"""
|
||||||
|
Adds two input scalar.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '+' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarAdd"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarSub(Primitive):
|
||||||
|
r"""
|
||||||
|
Subtracts the second input Scalar from the first input Scalar.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '-' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarSub"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarMul(Primitive):
|
||||||
|
r"""
|
||||||
|
Muls two input scalar.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '+' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarMul"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarEqual(Primitive):
|
||||||
|
r"""
|
||||||
|
Computes the equivalence between two Scalars.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '==' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type of scalar is bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarMul"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarGreater(Primitive):
|
||||||
|
r"""
|
||||||
|
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '>' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type of scalar is bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarGreater"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarLess(Primitive):
|
||||||
|
r"""
|
||||||
|
Computes the boolean value of :math:`x < y`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '<' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type of scalar is bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarLess"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarGreaterEqual(Primitive):
|
||||||
|
r"""
|
||||||
|
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '>=' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type of scalar is bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarGreaterEqual"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarLessEqual(Primitive):
|
||||||
|
r"""
|
||||||
|
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '<=' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type of scalar is bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarLessEqual"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarMod(Primitive):
|
||||||
|
r"""
|
||||||
|
Computes the remainder of dividing the first input scalar by the second input scalar element-wise.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
out_{i} = x_{i} \text{ % } y_{i}
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '%' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarMod"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarBool(Primitive):
|
||||||
|
r"""
|
||||||
|
Computes the input scalar true or false.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type is bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarBool"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarBitwiseAnd(Primitive):
|
||||||
|
r"""
|
||||||
|
Returns bitwise `and` of two scalars.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
out_{i} = x_{i} \text{ % } y_{i}
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '%' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar, the type can be int or bool.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar, the type can be int or bool.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarMod"""
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarBitwiseOr(Primitive):
|
||||||
|
r"""
|
||||||
|
Returns bitwise `or` of two scalars.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The inputs can be constant/variable value. Usage is the same as '|' in Python.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Scalar) - A constant or variable scalar, the type can be int or bool.
|
||||||
|
- **y** (Scalar) - A constant or variable scalar, the type can be int or bool.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Scalar, the type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` and `y` are not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize ScalarMod"""
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -43,7 +43,6 @@ class ListAppend(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ListAppend"""
|
"""Initialize ListAppend"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_data', 'target'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_data', 'target'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,7 +74,6 @@ class SequenceSlice(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize SequenceSlice"""
|
"""Initialize SequenceSlice"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['seq', 'start', 'stop', 'step'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['seq', 'start', 'stop', 'step'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,7 +105,6 @@ class SequenceSliceSetItem(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize SequenceSliceSetItem"""
|
"""Initialize SequenceSliceSetItem"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['seq', 'target', 'start', 'stop', 'step'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['seq', 'target', 'start', 'stop', 'step'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,7 +133,6 @@ class SequenceAdd(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize SequenceAdd"""
|
"""Initialize SequenceAdd"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_1', 'input_2'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_1', 'input_2'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,7 +164,6 @@ class TupleToTensor(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize TupleToTensor"""
|
"""Initialize TupleToTensor"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_tuple', 'dtype'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_tuple', 'dtype'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -200,7 +195,6 @@ class ListToTensor(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ListToTensor"""
|
"""Initialize ListToTensor"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_list', 'dtype'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_list', 'dtype'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -228,7 +222,6 @@ class TensorToTuple(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize TensorToTuple"""
|
"""Initialize TensorToTuple"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -256,7 +249,6 @@ class TensorToList(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize TensorToList"""
|
"""Initialize TensorToList"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -284,7 +276,6 @@ class TensorToScalar(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize TensorToScalar"""
|
"""Initialize TensorToScalar"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['input_tensor'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -312,7 +303,6 @@ class SequenceCount(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize SequenceCount"""
|
"""Initialize SequenceCount"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['sequence', 'target'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['sequence', 'target'], outputs=['output_data'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -341,5 +331,4 @@ class SequenceMul(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize SequenceMul"""
|
"""Initialize SequenceMul"""
|
||||||
self.add_prim_attr("primitive_target", "CPU")
|
|
||||||
self.init_prim_io_names(inputs=['sequence', 'scalar'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['sequence', 'scalar'], outputs=['output_data'])
|
||||||
|
|
|
@ -1730,9 +1730,9 @@ class ScalarToTensor(PrimitiveWithInfer):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
self.init_prim_io_names(inputs=['input_scalar', 'dtype'], outputs=['output_data'])
|
||||||
|
|
||||||
def infer_value(self, x, dtype=mstype.float32):
|
def __call__(self, x, dtype=mstype.float32):
|
||||||
validator.check_value_type("x", x, [int, float], self.name)
|
validator.check_value_type("x", x, [int, float], self.name)
|
||||||
validator.check_subclass("dtype", dtype, mstype.number, self.name)
|
validator.check_subclass("dtype", dtype, mstype.number, self.name)
|
||||||
data_type = mstype.dtype_to_nptype(dtype)
|
data_type = mstype.dtype_to_nptype(dtype)
|
||||||
|
|
|
@ -146,12 +146,12 @@ class YOLOv3(nn.Cell):
|
||||||
con1, big_object_output = self.backblock0(feature_map3)
|
con1, big_object_output = self.backblock0(feature_map3)
|
||||||
|
|
||||||
con1 = self.conv1(con1)
|
con1 = self.conv1(con1)
|
||||||
ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con1)
|
ups1 = P.ResizeNearestNeighbor((img_hight // 16, img_width // 16))(con1)
|
||||||
con1 = self.concat((ups1, feature_map2))
|
con1 = self.concat((ups1, feature_map2))
|
||||||
con2, medium_object_output = self.backblock1(con1)
|
con2, medium_object_output = self.backblock1(con1)
|
||||||
|
|
||||||
con2 = self.conv2(con2)
|
con2 = self.conv2(con2)
|
||||||
ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con2)
|
ups2 = P.ResizeNearestNeighbor((img_hight // 8, img_width // 8))(con2)
|
||||||
con3 = self.concat((ups2, feature_map1))
|
con3 = self.concat((ups2, feature_map1))
|
||||||
_, small_object_output = self.backblock2(con3)
|
_, small_object_output = self.backblock2(con3)
|
||||||
|
|
||||||
|
|
|
@ -112,17 +112,17 @@ TEST_F(TestOps, ScalarTanTest) {
|
||||||
|
|
||||||
// Comparisons
|
// Comparisons
|
||||||
TEST_F(TestOps, ScalarEqTest) {
|
TEST_F(TestOps, ScalarEqTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_eq");
|
auto prim = std::make_shared<Primitive>(prim::kScalarEq);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarEq->name());
|
ASSERT_EQ(prim->name(), kPrimScalarEq->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarLtTest) {
|
TEST_F(TestOps, ScalarLtTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_lt");
|
auto prim = std::make_shared<Primitive>(prim::kScalarLt);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarLt->name());
|
ASSERT_EQ(prim->name(), kPrimScalarLt->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarGtTest) {
|
TEST_F(TestOps, ScalarGtTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_gt");
|
auto prim = std::make_shared<Primitive>(prim::kScalarGt);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarGt->name());
|
ASSERT_EQ(prim->name(), kPrimScalarGt->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,12 +132,12 @@ TEST_F(TestOps, ScalarNeTest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarLeTest) {
|
TEST_F(TestOps, ScalarLeTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_le");
|
auto prim = std::make_shared<Primitive>(prim::kScalarLe);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarLe->name());
|
ASSERT_EQ(prim->name(), kPrimScalarLe->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarGeTest) {
|
TEST_F(TestOps, ScalarGeTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_ge");
|
auto prim = std::make_shared<Primitive>(prim::kScalarGe);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarGe->name());
|
ASSERT_EQ(prim->name(), kPrimScalarGe->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -164,7 +164,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(*abs_base_got->BuildValue() == *MakeValue(static_cast<int64_t>(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestInferGraph : public UT::Common {
|
class TestInferGraph : public UT::Common {
|
||||||
|
@ -273,7 +273,7 @@ TEST_F(TestInferGraph, test_inferred) {
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
args_spec_list.push_back(abstract_v2);
|
args_spec_list.push_back(abstract_v2);
|
||||||
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).eval_result->abstract();
|
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).eval_result->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(*abs_base_got->BuildValue() == *MakeValue(static_cast<int64_t>(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestInferGraph, test_context) {
|
TEST_F(TestInferGraph, test_context) {
|
||||||
|
@ -352,6 +352,7 @@ void TestInferMetaGraph::TearDown() {
|
||||||
TEST_F(TestInferMetaGraph, test_inferred) {
|
TEST_F(TestInferMetaGraph, test_inferred) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
int64_t v1 = 1;
|
int64_t v1 = 1;
|
||||||
|
int64_t res = 2;
|
||||||
std::cout << "Begin TestInferGraph." << std::endl;
|
std::cout << "Begin TestInferGraph." << std::endl;
|
||||||
std::cout << func_graph_->get_return()->ToString() << std::endl;
|
std::cout << func_graph_->get_return()->ToString() << std::endl;
|
||||||
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
||||||
|
@ -359,7 +360,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
args_spec_list.push_back(abstract_v2);
|
args_spec_list.push_back(abstract_v2);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).eval_result->abstract();
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).eval_result->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(*abs_base_got->BuildValue() == *MakeValue(res));
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestInferUniform : public UT::Common {
|
class TestInferUniform : public UT::Common {
|
||||||
|
|
|
@ -13,11 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" Test for GraphCloner """
|
""" Test for GraphCloner """
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import _constants as Constants
|
|
||||||
|
|
||||||
scala_add = Primitive(Constants.kScalarAdd)
|
scala_add = F.scalar_add
|
||||||
scalar_mul = Primitive(Constants.kScalarMul)
|
scalar_mul = F.scalar_mul
|
||||||
|
|
||||||
|
|
||||||
def test_clone_simple():
|
def test_clone_simple():
|
||||||
|
|
|
@ -17,11 +17,11 @@ import numpy as np
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import _constants as Constants
|
from mindspore.ops import functional as F
|
||||||
from tests.ut.python.model.resnet import resnet50
|
from tests.ut.python.model.resnet import resnet50
|
||||||
|
|
||||||
|
|
||||||
scala_add = Primitive(Constants.kScalarAdd)
|
scala_add = F.scalar_add
|
||||||
|
|
||||||
|
|
||||||
def scalar_add(x, y):
|
def scalar_add(x, y):
|
||||||
|
|
|
@ -29,8 +29,8 @@ from mindspore.ops.operations import _grad_ops as G
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
scalar_add = Primitive(Constants.kScalarAdd)
|
scalar_add = F.scalar_add
|
||||||
scalar_mul = Primitive(Constants.kScalarMul)
|
scalar_mul = F.scalar_mul
|
||||||
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
switch = Primitive('Switch')
|
switch = Primitive('Switch')
|
||||||
|
|
||||||
|
@ -354,7 +354,7 @@ def test_inline_while(tag):
|
||||||
def test_cse(tag):
|
def test_cse(tag):
|
||||||
""" test_cse """
|
""" test_cse """
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
scalar_div = Primitive(Constants.kScalarDiv)
|
scalar_div = F.scalar_div
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def test_f1(x, y):
|
def test_f1(x, y):
|
||||||
|
@ -774,7 +774,7 @@ def test_incorporate_getitem(tag):
|
||||||
def test_incorporate_getitem_through_switch(tag):
|
def test_incorporate_getitem_through_switch(tag):
|
||||||
""" test_incorporate_getitem_through_switch """
|
""" test_incorporate_getitem_through_switch """
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
scalar_gt = Primitive('scalar_gt')
|
scalar_gt = F.scalar_gt
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def before(x, y):
|
def before(x, y):
|
||||||
|
@ -834,7 +834,7 @@ def test_incorporate_call_through_switch(tag):
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
f1 = Primitive('f1')
|
f1 = Primitive('f1')
|
||||||
f2 = Primitive('f2')
|
f2 = Primitive('f2')
|
||||||
scalar_gt = Primitive('scalar_gt')
|
scalar_gt = F.scalar_gt
|
||||||
identity = Primitive('identity')
|
identity = Primitive('identity')
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
|
@ -869,7 +869,7 @@ def test_incorporate_call_through_switch(tag):
|
||||||
def test_float_tuple_getitem_through_switch(tag):
|
def test_float_tuple_getitem_through_switch(tag):
|
||||||
""" test_float_tuple_getitem_through_switch """
|
""" test_float_tuple_getitem_through_switch """
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
scalar_gt = Primitive('scalar_gt')
|
scalar_gt = F.scalar_gt
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def before(x, y):
|
def before(x, y):
|
||||||
|
@ -931,7 +931,7 @@ def test_convert_switch_ops(tag):
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
ge_switch = Primitive('GeSwitch')
|
ge_switch = Primitive('GeSwitch')
|
||||||
merge = Primitive('Merge')
|
merge = Primitive('Merge')
|
||||||
add = Primitive(Constants.kScalarAdd)
|
add = F.scalar_add
|
||||||
neg = Primitive('Neg')
|
neg = Primitive('Neg')
|
||||||
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
make_tuple = Primitive('MakeTuple')
|
make_tuple = Primitive('MakeTuple')
|
||||||
|
|
|
@ -13,12 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.ops import Primitive
|
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import _constants as Constants
|
|
||||||
|
|
||||||
scala_add = Primitive(Constants.kScalarAdd)
|
scala_add = F.scalar_add
|
||||||
|
|
||||||
|
|
||||||
class AddNet(nn.Cell):
|
class AddNet(nn.Cell):
|
||||||
|
|
|
@ -13,8 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" multi_relu_case """
|
""" multi_relu_case """
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import _constants as Constants
|
|
||||||
|
|
||||||
|
|
||||||
# Test user define ops
|
# Test user define ops
|
||||||
|
@ -22,7 +21,7 @@ def get_test_ops_fn():
|
||||||
return test_ops_f
|
return test_ops_f
|
||||||
|
|
||||||
|
|
||||||
scalar_mul = Primitive(Constants.kScalarMul)
|
scalar_mul = F.scalar_mul
|
||||||
|
|
||||||
|
|
||||||
def test_ops_f(x, y):
|
def test_ops_f(x, y):
|
||||||
|
|
|
@ -13,12 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" vm_test """
|
""" vm_test """
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import _constants as Constants
|
|
||||||
|
|
||||||
scala_add = Primitive(Constants.kScalarAdd)
|
scala_add = F.scalar_add
|
||||||
scala_mul = Primitive(Constants.kScalarMul)
|
scala_mul = F.scalar_mul
|
||||||
scalar_gt = Primitive('scalar_gt')
|
scalar_gt = F.scalar_gt
|
||||||
|
|
||||||
|
|
||||||
def ScalarAdd(x, y):
|
def ScalarAdd(x, y):
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test mul operation for dynamic sequence and variable integer in graph mode"""
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import jit
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constant_scalar_div_and_mod():
|
||||||
|
"""
|
||||||
|
Feature: Constant scalar div and mod operation.
|
||||||
|
Description:
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def foo():
|
||||||
|
return 2/3, -2%3
|
||||||
|
|
||||||
|
ret1, ret2 = foo()
|
||||||
|
tol = 1e-6
|
||||||
|
assert np.abs(ret1 - 0.666666) < tol
|
||||||
|
assert np.abs(ret2 - 1) < tol
|
||||||
|
|
||||||
|
|
||||||
|
def test_constant_scalar_bitwise():
|
||||||
|
"""
|
||||||
|
Feature: Constant scalar bitwise operation.
|
||||||
|
Description:
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def foo():
|
||||||
|
return 3 & 1, True & 4, 4 | 3
|
||||||
|
|
||||||
|
ret1, ret2, ret3 = foo()
|
||||||
|
assert ret1 == 1
|
||||||
|
assert ret2 == 0
|
||||||
|
assert ret3 == 7
|
||||||
|
|
||||||
|
|
||||||
|
def test_variable_scalar_div_and_mod():
|
||||||
|
"""
|
||||||
|
Feature: Variable scalar div and mod operation.
|
||||||
|
Description:
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def foo():
|
||||||
|
x = mutable(2)
|
||||||
|
y = mutable(3)
|
||||||
|
ret1 = x / y
|
||||||
|
ret2 = x % y
|
||||||
|
return isinstance(ret1, float), F.isconstant(ret2)
|
||||||
|
|
||||||
|
ret1, ret2 = foo()
|
||||||
|
assert ret1
|
||||||
|
assert not ret2
|
||||||
|
|
||||||
|
|
||||||
|
def test_variable_scalar_bitwise():
|
||||||
|
"""
|
||||||
|
Feature: Variable scalar bitwise operation.
|
||||||
|
Description:
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def foo():
|
||||||
|
x = mutable(2)
|
||||||
|
y = mutable(3)
|
||||||
|
ret1 = x & y
|
||||||
|
ret2 = x | y
|
||||||
|
return isinstance(ret1, int), F.isconstant(ret2)
|
||||||
|
|
||||||
|
ret1, ret2 = foo()
|
||||||
|
assert ret1
|
||||||
|
assert not ret2
|
|
@ -142,9 +142,8 @@ def test_bitwise_operator_error_float_input():
|
||||||
return res
|
return res
|
||||||
|
|
||||||
net = Net()
|
net = Net()
|
||||||
with pytest.raises(TypeError) as err:
|
with pytest.raises(TypeError):
|
||||||
net()
|
net()
|
||||||
assert "Unsupported input type. For BitOr, only integer types are supported, but got" in str(err.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_bitwise_operator_error_too_large_number():
|
def test_bitwise_operator_error_too_large_number():
|
||||||
|
|
|
@ -17,8 +17,6 @@ import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common.api import jit
|
from mindspore.common.api import jit
|
||||||
from mindspore.ops import Primitive
|
|
||||||
from mindspore.ops import _constants
|
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
@ -29,7 +27,7 @@ from ...ut_filter import non_graph_engine
|
||||||
|
|
||||||
|
|
||||||
tensor_add = P.Add()
|
tensor_add = P.Add()
|
||||||
scala_add = Primitive(_constants.kScalarAdd)
|
scala_add = F.scalar_add
|
||||||
add = C.MultitypeFuncGraph('add')
|
add = C.MultitypeFuncGraph('add')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,16 +18,15 @@ import numpy as np
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common.api import jit
|
from mindspore.common.api import jit
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.ops import Primitive
|
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import _constants
|
from mindspore.ops import functional as F
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
from ...ut_filter import non_graph_engine
|
from ...ut_filter import non_graph_engine
|
||||||
|
|
||||||
tensor_add = P.Add()
|
tensor_add = P.Add()
|
||||||
op_add = P.AddN()
|
op_add = P.AddN()
|
||||||
scala_add = Primitive(_constants.kScalarAdd)
|
scala_add = F.scalar_add
|
||||||
add = C.MultitypeFuncGraph('add')
|
add = C.MultitypeFuncGraph('add')
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue