!11982 Change tuple_getitem to TupleGetItem merge from r1.1 to master

From: @liangzhibo
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
commit a616196586
mindspore-ci-bot 2021-02-02 19:06:39 +08:00, committed by Gitee
77 changed files with 252 additions and 161 deletions

View File

@@ -21,32 +21,32 @@ import mindspore.common.dtype as mstype
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
 
-def scalar_add(x, y):
+def ScalarAdd(x, y):
     """Implement `scalar_add`."""
     return x + y
 
-def scalar_mul(x, y):
+def ScalarMul(x, y):
     """Implement `scalar_mul`."""
     return x * y
 
-def scalar_mod(x, y):
+def ScalarMod(x, y):
     """Implement `scalar_mul`."""
     return x % y
 
-def scalar_sub(x, y):
+def ScalarSub(x, y):
     """Implement `scalar_sub`."""
     return x - y
 
-def scalar_usub(x):
+def ScalarUsub(x):
     """Implement `scalar_usub`."""
     return -x
 
-def tuple_getitem(x, index):
+def TupleGetItem(x, index):
     """Implement `tuple_getitem`."""
     if isinstance(x, Tensor):
         x = x.asnumpy()
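
The hunk above is truncated after x.asnumpy(). A minimal sketch of what the renamed vm-impl helper plausibly looks like in full, assuming the missing tail simply indexes and re-wraps Tensor results (the tail is an assumption, not the commit's code):

# Sketch only: the diff truncates TupleGetItem's body after x.asnumpy().
from mindspore.common.tensor import Tensor

def TupleGetItem(x, index):
    """Implement `tuple_getitem`."""
    if isinstance(x, Tensor):
        x = x.asnumpy()
        y = x[index]          # index into the converted ndarray
        return Tensor(y)      # re-wrap the result as a Tensor
    return x[index]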

View File

@@ -15,6 +15,7 @@
  */
 #include "frontend/operator/prim_to_function.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 // namespace to support prim related definition
@@ -25,31 +26,31 @@ PrimToFunction::PrimToFunction()
 {"bool_not", kPrimTypeOneArg},
 {"scalar_cos", kPrimTypeOneArg},
 {"scalar_exp", kPrimTypeOneArg},
-{"scalar_floor", kPrimTypeOneArg},
+{kScalarFloor, kPrimTypeOneArg},
 {"scalar_log", kPrimTypeOneArg},
 {"scalar_sin", kPrimTypeOneArg},
 {"scalar_tan", kPrimTypeOneArg},
-{"scalar_trunc", kPrimTypeOneArg},
+{kScalarTrunc, kPrimTypeOneArg},
 {"typeof", kPrimTypeOneArg},
-{"scalar_uadd", kPrimTypeOneArg},
-{"scalar_usub", kPrimTypeOneArg},
+{kScalarUadd, kPrimTypeOneArg},
+{kScalarUsub, kPrimTypeOneArg},
 // TWO_ARGS prim
-{"scalar_add", kPrimTypeTwoArgs},
+{kScalarAdd, kPrimTypeTwoArgs},
 {"bool_and", kPrimTypeTwoArgs},
 {"bool_eq", kPrimTypeTwoArgs},
 {"bool_or", kPrimTypeTwoArgs},
-{"scalar_div", kPrimTypeTwoArgs},
+{kScalarDiv, kPrimTypeTwoArgs},
 {"scalar_eq", kPrimTypeTwoArgs},
 {"scalar_ge", kPrimTypeTwoArgs},
 {"scalar_gt", kPrimTypeTwoArgs},
 {"scalar_le", kPrimTypeTwoArgs},
 {"scalar_lt", kPrimTypeTwoArgs},
 {"scalar_ne", kPrimTypeTwoArgs},
-{"scalar_mod", kPrimTypeTwoArgs},
-{"scalar_mul", kPrimTypeTwoArgs},
-{"scalar_pow", kPrimTypeTwoArgs},
-{"scalar_sub", kPrimTypeTwoArgs},
-{"scalar_floordiv", kPrimTypeTwoArgs}}) {}
+{kScalarMod, kPrimTypeTwoArgs},
+{kScalarMul, kPrimTypeTwoArgs},
+{kScalarPow, kPrimTypeTwoArgs},
+{kScalarSub, kPrimTypeTwoArgs},
+{kScalarFloordiv, kPrimTypeTwoArgs}}) {}
 
 bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
 bool result = false;
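
The point of this hunk is that the one-arg/two-arg dispatch table now shares its keys with base/core_ops.h instead of repeating string literals. A hedged Python analogue of the same pattern, using only constants that appear elsewhere in this diff:

# Keying an arity table on shared constants instead of duplicated literals
# means a rename in one place cannot silently desynchronize the table.
from mindspore.ops import _constants as Constants

PRIM_ARITY = {
    Constants.kScalarUadd: 1,
    Constants.kScalarUsub: 1,
    Constants.kScalarAdd: 2,
    Constants.kScalarSub: 2,
    Constants.kScalarMul: 2,
    Constants.kScalarDiv: 2,
    Constants.kScalarMod: 2,
    Constants.kScalarPow: 2,
}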

View File

@@ -18,6 +18,7 @@
 #include <string>
+#include "base/core_ops.h"
 #include "ir/param_info.h"
 #include "ir/meta_tensor.h"
 #include "pipeline/jit/parse/python_adapter.h"
@@ -306,7 +307,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
 }
 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
-if (prim->name() == TUPLE_GETITEM) {
+if (prim->name() == prim::kTupleGetItem) {
 *out_index = GetTupleGetItemIndex(cnode);
 // find tuple_get_item's previous node
 auto pre_node = cnode->input(1);

View File

@@ -17,6 +17,8 @@
 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_
 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_
+#include "base/core_ops.h"
+
 namespace mindspore {
 namespace parallel {
 constexpr size_t PRELU_INPUTS_SIZE = 2;
@@ -320,7 +322,6 @@ constexpr char KStridedSlice[] = "StridedSlice";
 constexpr char UNIQUE[] = "Unique";
 // Parallel don't care
-constexpr char TUPLE_GETITEM[] = "tuple_getitem";
 constexpr char STRING_EQUAL[] = "string_equal";
 constexpr char MAKE_TUPLE[] = "make_tuple";
 constexpr char MAKE_LIST[] = "make_list";

View File

@@ -22,6 +22,7 @@
 #include <vector>
 #include "ir/value.h"
+#include "base/core_ops.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/strategy.h"
@@ -206,8 +207,8 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)});
 auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
 auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()});
-auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(0)});
-auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(1)});
+auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(prim::kTupleGetItem), unique, CreatInt64Imm(0)});
+auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(prim::kTupleGetItem), unique, CreatInt64Imm(1)});
 auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1});
 auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
 auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast});

View File

@@ -28,6 +28,7 @@
 #include <vector>
 #include <unordered_set>
+#include "base/core_ops.h"
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
@@ -599,8 +600,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
 PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
 size_t output_index = 0;
-bool bool_result =
-  (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
+bool bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
+  (prev_prim->name() == DEPEND);
 while (bool_result) {
 if (IsAutoParallelCareNode(prev_cnode)) {
 auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
@@ -639,7 +640,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
 edge_count++;
 break;
-} else if (prev_prim->name() == TUPLE_GETITEM) {
+} else if (prev_prim->name() == prim::kTupleGetItem) {
 // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
 // this 'tuple_getitem'
 MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
@@ -672,8 +673,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
 << "and creating an edge between the Operator before "
 << "'depend' and the Operator after 'depend'.";
 }
-bool_result =
-  (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
+bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
+  (prev_prim->name() == DEPEND);
 }
 }
 MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
@@ -960,13 +961,13 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
 CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
+if (prim->name() == prim::kTupleGetItem || prim->name() == DEPEND) {
 auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
 if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
 return nullptr;
 }
 auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
-while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
+while (prev_prim->name() == prim::kTupleGetItem || prev_prim->name() == DEPEND) {
 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
 if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
 return nullptr;
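
Both ConstructCostGraphEdges and GetInternalOperatorInfo rely on the same idea: tuple_getitem and Depend are pass-through nodes, so the walk continues through input(1) until a real operator appears. A rough Python transcription of that traversal (the node type and its attributes here are hypothetical, for illustration only):

PASS_THROUGH = {"TupleGetItem", "Depend"}  # primitive names after this PR

def find_real_producer(node):
    # Follow input(1), which carries the wrapped value, past pass-through nodes.
    while node is not None and node.prim_name in PASS_THROUGH:
        node = node.inputs[1] if len(node.inputs) > 1 else None
    return node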

View File

@@ -27,6 +27,7 @@
 #include <unordered_map>
 #include <utility>
+#include "base/core_ops.h"
 #include "frontend/operator/ops.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
@@ -311,7 +312,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
 }
 PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
 MS_EXCEPTION_IF_NULL(value_node_prim);
-if (value_node_prim->name() == TUPLE_GETITEM) {
+if (value_node_prim->name() == prim::kTupleGetItem) {
 if (uses_set.size() > 1) {
 MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size();
 }
@@ -409,7 +410,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
 TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim,
 const OperatorInfoPtr &distribute_operator) {
 TensorInfo tensorinfo_in;
-if (middle_prim->name() == TUPLE_GETITEM) {
+if (middle_prim->name() == prim::kTupleGetItem) {
 auto value_node = middle_node->input(2)->cast<ValueNodePtr>();
 MS_EXCEPTION_IF_NULL(value_node);
 size_t index_s = LongToSize(GetValue<int64_t>(value_node->value()));
@@ -603,7 +604,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
 MS_EXCEPTION_IF_NULL(current_value);
 PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
 MS_EXCEPTION_IF_NULL(current_prim);
-insert_node_new = ((current_prim->name() == TUPLE_GETITEM) ? node : insert_node);
+insert_node_new = ((current_prim->name() == prim::kTupleGetItem) ? node : insert_node);
 } else {
 insert_node_new = insert_node;
 }
@@ -2117,7 +2118,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
 }
 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
-if (prim->name() == TUPLE_GETITEM) {
+if (prim->name() == prim::kTupleGetItem) {
 auto tuple_index = GetTupleGetItemIndex(cnode);
 auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
 if (!layout_ptr) {
@@ -2234,7 +2235,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
 }
 // return -> tuple_getitem -> loss
-if (current_prim->name() == TUPLE_GETITEM) {
+if (current_prim->name() == prim::kTupleGetItem) {
 auto tuple_index = GetTupleGetItemIndex(pre_cnode);
 AnfNodePtr pre_pre_node = pre_cnode->input(1);
 MS_EXCEPTION_IF_NULL(pre_pre_node);
@@ -2491,7 +2492,7 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
 }
 auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
-if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
+if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kTupleGetItem)) {
 continue;
 }

View File

@@ -21,6 +21,7 @@
 #include <stack>
 #include "utils/utils.h"
+#include "base/core_ops.h"
 #include "frontend/operator/ops.h"
 #include "utils/log_adapter.h"
 #include "ir/graph_utils.h"
@@ -631,7 +632,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
 MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
 }
 TraceOutput(c->input(1));
-} else if (name == "tuple_getitem") {
+} else if (name == prim::kTupleGetItem) {
 TraceOutputFromTupleGetItem(anf_out);
 } else {
 // add outputs;
@@ -1014,7 +1015,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
 if (it != out_handle_cache_.end()) {
 int ret = adpt->setInput(src, SizeToInt(i), it->second);
 if (ret == 0) {
-if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
+if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kTupleGetItem) {
 compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
 << ":" << i << endl;
 } else if (pred->isa<Parameter>()) {
@@ -1538,7 +1539,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
 }
 // As for nodes with multi outputs, convert tuple_getitem to OutHandle
-if (name == "tuple_getitem") {
+if (name == prim::kTupleGetItem) {
 ConvertTupleGetItem(node);
 return false;
 }

View File

@@ -26,19 +26,33 @@
 namespace mindspore {
 namespace prim {
 constexpr auto kGather = "Gather";
+// Arithmetic
+constexpr auto kScalarAdd = "ScalarAdd";
+constexpr auto kScalarSub = "ScalarSub";
+constexpr auto kScalarMul = "ScalarMul";
+constexpr auto kScalarDiv = "ScalarDiv";
+constexpr auto kScalarFloordiv = "ScalarFloordiv";
+constexpr auto kScalarMod = "ScalarMod";
+constexpr auto kScalarPow = "ScalarPow";
+constexpr auto kScalarTrunc = "ScalarTrunc";
+constexpr auto kScalarFloor = "ScalarFloor";
+constexpr auto kScalarUadd = "ScalarUadd";
+constexpr auto kScalarUsub = "ScalarUsub";
+constexpr auto kTupleGetItem = "TupleGetItem";
+
 // Here list all primitives used in backend or some special primitives used by core.
 // Arithmetic
-inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
-inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
-inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
-inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
-inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
-inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
-inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
-inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
-inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
-inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
-inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
+inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>(kScalarAdd);
+inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>(kScalarSub);
+inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>(kScalarMul);
+inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>(kScalarDiv);
+inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>(kScalarFloordiv);
+inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>(kScalarMod);
+inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>(kScalarPow);
+inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>(kScalarTrunc);
+inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>(kScalarFloor);
+inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>(kScalarUadd);
+inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>(kScalarUsub);
 inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
 inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
 inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
@@ -295,7 +309,7 @@ inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
 inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
 inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
-inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
+inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>(kTupleGetItem);
 inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
 inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
 inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
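
These constexpr strings are the single source of truth for primitive names on the C++ side; the Python side mirrors them in mindspore/ops/_constants.py, which the test hunks below import. A plausible sketch of that module, showing only the names this diff actually uses, with values assumed to match the C++ definitions above:

# mindspore/ops/_constants.py (sketch; values assumed to mirror core_ops.h)
kScalarAdd = "ScalarAdd"
kScalarSub = "ScalarSub"
kScalarMul = "ScalarMul"
kScalarDiv = "ScalarDiv"
kScalarFloordiv = "ScalarFloordiv"
kScalarMod = "ScalarMod"
kScalarPow = "ScalarPow"
kScalarUadd = "ScalarUadd"
kScalarUsub = "ScalarUsub"
kTupleGetItem = "TupleGetItem"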

View File

@@ -19,9 +19,11 @@
 #include <set>
 #include <string>
+#include "base/core_ops.h"
+
 namespace mindspore {
 // clang-format off
-static const std::set<std::string> PARALLEL_BLACK_LIST_ = {"tuple_getitem", "J", "list_getitem",
+static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
 "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
 "list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
 "make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem",

View File

@@ -14,6 +14,7 @@
 # ============================================================================
 """bprop primitives"""
+from mindspore.ops import _constants
 from ..operations import _grad_ops as G
 from .. import functional as F
 from .. import operations as P
@@ -50,31 +51,31 @@ def bprop_relu_grad_grad(x, y, out, dout):
     return dy, F.zeros_like(y)
 
-@bprops.register("scalar_add")
+@bprops.register(_constants.kScalarAdd)
 def bprop_scalar_add(x, y, out, dout):
     """Backpropagator for primitive `scalar_add`."""
     return dout, dout
 
-@bprops.register("scalar_mul")
+@bprops.register(_constants.kScalarMul)
 def bprop_scalar_mul(x, y, out, dout):
     """Backpropagator for primitive `scalar_mul`."""
     return dout*y, dout*x
 
-@bprops.register("scalar_sub")
+@bprops.register(_constants.kScalarSub)
 def bprop_scalar_sub(x, y, out, dout):
     """Backpropagator for primitive `scalar_sub`."""
     return dout, -dout
 
-@bprops.register("scalar_div")
+@bprops.register(_constants.kScalarDiv)
 def bprop_scalar_div(x, y, out, dout):
     """Backpropagator for primitive `scalar_div`."""
     return dout/y, (-dout) * (out/y)
 
-@bprops.register("scalar_pow")
+@bprops.register(_constants.kScalarPow)
 def bprop_scalar_pow(x, y, out, dout):
     """Backpropagator for primitive `scalar_pow`."""
     return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)
@@ -86,13 +87,13 @@ def bprop_scalar_exp(x, out, dout):
     return (dout * out,)
 
-@bprops.register("scalar_uadd")
+@bprops.register(_constants.kScalarUadd)
 def bprop_scalar_uadd(x, out, dout):
     """Backpropagator for primitive `scalar_uadd`."""
     return (dout,)
 
-@bprops.register("scalar_usub")
+@bprops.register(_constants.kScalarUsub)
 def bprop_scalar_usub(x, out, dout):
     """Backpropagator for primitive `scalar_usub`."""
     return (-dout,)
@@ -140,7 +141,7 @@ def bprop_scalar_cast(x, t, out, dout):
     return F.scalar_cast(dout, F.typeof(x)), t
 
-@bprops.register("tuple_getitem")
+@bprops.register(_constants.kTupleGetItem)
 def bprop_tuple_getitem(data, idx, out, dout):
     """Backpropagator for primitive `tuple_getitem`."""
     return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
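
The register decorator keys each backpropagator by primitive name, so switching the key from the literal 'scalar_add' to _constants.kScalarAdd moves the lookup key to the new CamelCase spelling. A toy registry illustrating the pattern (a hypothetical stand-in, not MindSpore's actual Registry class):

bprop_table = {}

def register(prim_name):
    """Record fn under prim_name, mimicking bprops.register above."""
    def deco(fn):
        bprop_table[prim_name] = fn
        return fn
    return deco

@register("ScalarAdd")  # assumed value of _constants.kScalarAdd
def bprop_scalar_add(x, y, out, dout):
    return dout, dout

assert "ScalarAdd" in bprop_table and "scalar_add" not in bprop_table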

View File

@@ -18,11 +18,11 @@
 """The names of functional part are summarized here."""
 from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore.ops import _constants
 from .primitive import Primitive
 from . import operations as P
 from .operations import _grad_ops
 
 typeof = Primitive('typeof')
 hastype = Primitive('hastype')
 cast = P.Cast()
@@ -96,7 +96,7 @@ depend = P.Depend()
 identity = P.identity()
 tuple_setitem = Primitive('tuple_setitem')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(_constants.kTupleGetItem)
 list_getitem = Primitive('list_getitem')
 list_setitem = Primitive('list_setitem')
 dict_getitem = Primitive('dict_getitem')
@@ -114,22 +114,22 @@ tuple_equal = Primitive("tuple_equal")
 list_equal = Primitive("list_equal")
 make_ref = Primitive("make_ref")
 
-scalar_add = Primitive('scalar_add')
-scalar_mul = Primitive('scalar_mul')
-scalar_sub = Primitive('scalar_sub')
-scalar_div = Primitive('scalar_div')
-scalar_floordiv = Primitive('scalar_floordiv')
+scalar_add = Primitive(_constants.kScalarAdd)
+scalar_mul = Primitive(_constants.kScalarMul)
+scalar_sub = Primitive(_constants.kScalarSub)
+scalar_div = Primitive(_constants.kScalarDiv)
+scalar_floordiv = Primitive(_constants.kScalarFloordiv)
 scalar_log = Primitive('scalar_log')
-scalar_pow = Primitive('scalar_pow')
+scalar_pow = Primitive(_constants.kScalarPow)
 scalar_gt = Primitive('scalar_gt')
 scalar_ge = Primitive('scalar_ge')
 scalar_le = Primitive('scalar_le')
 scalar_lt = Primitive('scalar_lt')
 scalar_eq = Primitive('scalar_eq')
 scalar_ne = Primitive('scalar_ne')
-scalar_uadd = Primitive('scalar_uadd')
-scalar_usub = Primitive('scalar_usub')
-scalar_mod = Primitive('scalar_mod')
+scalar_uadd = Primitive(_constants.kScalarUadd)
+scalar_usub = Primitive(_constants.kScalarUsub)
+scalar_mod = Primitive(_constants.kScalarMod)
 string_eq = Primitive('string_equal')
 string_concat = Primitive('string_concat')
 bool_not = Primitive("bool_not")
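
Note that only the Primitive's registered name changes; the snake_case functional aliases (F.scalar_add, F.tuple_getitem, ...) keep their Python names. A hedged sanity check of the new behavior, assuming Primitive exposes its registered name via the name attribute:

from mindspore.ops import functional as F
from mindspore.ops import _constants

# The alias keeps its old Python name, but the primitive it wraps now
# reports the CamelCase name (assumed to be "ScalarAdd" per this PR).
assert F.scalar_add.name == _constants.kScalarAdd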

View File

@@ -21,6 +21,7 @@
 #include "ir/anf.h"
 #include "ir/func_graph.h"
 #include "frontend/operator/ops.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
@@ -32,7 +33,7 @@ class TestAnf : public UT::Common {
 };
 
 TEST_F(TestAnf, test_ValueNode) {
-auto prim = std::make_shared<Primitive>("scalar_add");
+auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
 ValueNodePtr c = NewValueNode(prim);
 ASSERT_EQ(c->isa<ValueNode>(), true);
 ASSERT_EQ(IsValueNode<Primitive>(c), true);

View File

@@ -24,6 +24,7 @@
 #include "pipeline/jit/parse/parse.h"
 #include "ir/graph_utils.h"
 #include "debug/draw.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 class TestCloner : public UT::Common {
@@ -89,7 +90,7 @@ TEST_F(TestCloner, test_clone_simple) {
 Cloner cl2(gs);
 auto g3 = cl2[g];
 
-std::vector<Primitive> results = {Primitive("scalar_add"), Primitive("scalar_mul"), Primitive("return")};
+std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("return")};
 AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
 common = d1 & d3;
 for (auto& x : common) {

View File

@@ -22,6 +22,7 @@
 #include "pybind_api/ir/primitive_py.h"
 #include "pipeline/jit/parse/python_adapter.h"
 #include "frontend/operator/ops.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 namespace prim {
@@ -34,52 +35,52 @@ class TestOps : public UT::Common {
 // Arithmetic
 TEST_F(TestOps, ScalarAddTest) {
-auto prim = std::make_shared<Primitive>("scalar_add");
+auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
 ASSERT_EQ(prim->name(), kPrimScalarAdd->name());
 }
 
 TEST_F(TestOps, ScalarSubTest) {
-auto prim = std::make_shared<Primitive>("scalar_sub");
+auto prim = std::make_shared<Primitive>(prim::kScalarSub);
 ASSERT_EQ(prim->name(), kPrimScalarSub->name());
 }
 
 TEST_F(TestOps, ScalarMulTest) {
-auto prim = std::make_shared<Primitive>("scalar_mul");
+auto prim = std::make_shared<Primitive>(prim::kScalarMul);
 ASSERT_EQ(prim->name(), kPrimScalarMul->name());
 }
 
 TEST_F(TestOps, ScalarDivTest) {
-auto prim = std::make_shared<Primitive>("scalar_div");
+auto prim = std::make_shared<Primitive>(prim::kScalarDiv);
 ASSERT_EQ(prim->name(), kPrimScalarDiv->name());
 }
 
 TEST_F(TestOps, ScalarModTest) {
-auto prim = std::make_shared<Primitive>("scalar_mod");
+auto prim = std::make_shared<Primitive>(prim::kScalarMod);
 ASSERT_EQ(prim->name(), kPrimScalarMod->name());
 }
 
 TEST_F(TestOps, ScalarPowTest) {
-auto prim = std::make_shared<Primitive>("scalar_pow");
+auto prim = std::make_shared<Primitive>(prim::kScalarPow);
 ASSERT_EQ(prim->name(), kPrimScalarPow->name());
 }
 
 TEST_F(TestOps, ScalarTruncTest) {
-auto prim = std::make_shared<Primitive>("scalar_trunc");
+auto prim = std::make_shared<Primitive>(prim::kScalarTrunc);
 ASSERT_EQ(prim->name(), kPrimScalarTrunc->name());
 }
 
 TEST_F(TestOps, ScalarFloorTest) {
-auto prim = std::make_shared<Primitive>("scalar_floor");
+auto prim = std::make_shared<Primitive>(prim::kScalarFloor);
 ASSERT_EQ(prim->name(), kPrimScalarFloor->name());
 }
 
 TEST_F(TestOps, ScalarUaddTest) {
-auto prim = std::make_shared<Primitive>("scalar_uadd");
+auto prim = std::make_shared<Primitive>(prim::kScalarUadd);
 ASSERT_EQ(prim->name(), kPrimScalarUadd->name());
 }
 
 TEST_F(TestOps, ScalarUsubTest) {
-auto prim = std::make_shared<Primitive>("scalar_usub");
+auto prim = std::make_shared<Primitive>(prim::kScalarUsub);
 ASSERT_EQ(prim->name(), kPrimScalarUsub->name());
 }
@@ -187,7 +188,7 @@ TEST_F(TestOps, MakeRecordTest) {
 }
 
 TEST_F(TestOps, TupleGetItemTest) {
-auto prim = std::make_shared<Primitive>("tuple_getitem");
+auto prim = std::make_shared<Primitive>(kTupleGetItem);
 ASSERT_EQ(prim->name(), kPrimTupleGetItem->name());
 }

View File

@@ -22,6 +22,7 @@
 #include "ir/anf.h"
 #include "ir/dtype.h"
 #include "frontend/operator/prim_to_function.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 namespace prim {
@@ -33,7 +34,7 @@ class TestPrimFunc : public UT::Common {
 };
 
 TEST_F(TestPrimFunc, ScalarAddTest) {
-auto prim = std::make_shared<Primitive>("scalar_add");
+auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
 FunctionPtr func = nullptr;
 PrimToFunction::GetInstance().GetFunction(prim, &func);

View File

@@ -27,6 +27,7 @@
 #include "debug/draw.h"
 #include "ir/tensor.h"
 #include "utils/symbolic.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 namespace abstract {
@@ -154,7 +155,7 @@ TEST_F(TestPrim, test_list_map) {
 AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), false);
 AbstractBasePtr abstract_u2 = FromValue(static_cast<int64_t>(2), false);
 auto abstract_list2 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v2, abstract_u2}));
-auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
+auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
 AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
 
 args_spec_list.push_back(abstract_func);
@@ -179,7 +180,7 @@ TEST_F(TestPrim, test_list_reduce) {
 AbstractBasePtr abstract_v1 = FromValue(v1, false);
 AbstractBasePtr abstract_v2 = FromValue(v1, false);
 auto abstract_list = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
-auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
+auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
 AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
 
 args_spec_list.push_back(abstract_func);

View File

@@ -27,6 +27,7 @@
 #include "ir/graph_utils.h"
 #include "utils/misc.h"
 #include "debug/draw.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 namespace abstract {
@@ -95,7 +96,7 @@ void TestSpecializeGraph::SetUp() {
 // build func_graph beta
 ParameterPtr x1 = graph_beta_->add_parameter();
 inputs.clear();
-inputs.push_back(NewValueNode(std::make_shared<Primitive>("scalar_add")));
+inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd)));
 inputs.push_back(x1);
 inputs.push_back(y);
 CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
@@ -166,7 +167,7 @@ class MetaScalarAdd : public MetaFuncGraph {
 FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
 ParameterPtr x = graph_g->add_parameter();
 ParameterPtr y = graph_g->add_parameter();
-auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
+auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
 std::vector<AnfNodePtr> inputs;
 inputs.push_back(NewValueNode(prim_scalar_add));
 inputs.push_back(x);

View File

@@ -28,6 +28,7 @@
 #include "pipeline/jit/resource.h"
 #include "debug/draw.h"
 #include "utils/log_adapter.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 namespace abstract {
@@ -96,7 +97,7 @@ class MetaScalarAdd : public MetaFuncGraph {
 FuncGraphPtr fg = std::make_shared<FuncGraph>();
 ParameterPtr x = fg->add_parameter();
 ParameterPtr y = fg->add_parameter();
-auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
+auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
 std::vector<AnfNodePtr> inputs;
 inputs.push_back(NewValueNode(prim_scalar_add));
 inputs.push_back(x);
@@ -161,7 +162,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
 args_spec_list.push_back(abstract_v1);
 args_spec_list.push_back(abstract_v2);
 
-auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
+auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
 FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
 AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
 ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
@@ -388,7 +389,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
 args_spec.push_back(abstract_v1);
 args_spec.push_back(abstract_v2);
 
-auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
+auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
 FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
 AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
 ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
@@ -418,7 +419,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
 AbstractBasePtr base1 = FromValue(x1, false);
 AbstractBasePtr base2 = FromValue(x2, false);
 AbstractBasePtrList base_list = {base1, base2};
 
-auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract();
+auto res = EvalOnePrim(std::make_shared<Primitive>(prim::kScalarAdd), base_list)->abstract();
 MS_LOG(INFO) << "result spec: " << res->ToString();
 AbstractBasePtr exp = FromValue(x3, false);
 MS_LOG(INFO) << "result exp: " << exp->ToString();

View File

@@ -14,9 +14,10 @@
 # ============================================================================
 """ Test for GraphCloner """
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 
-scala_add = Primitive('scalar_add')
-scalar_mul = Primitive('scalar_mul')
+scala_add = Primitive(Constants.kScalarAdd)
+scalar_mul = Primitive(Constants.kScalarMul)
 
 def test_clone_simple():

View File

@@ -18,9 +18,11 @@ import numpy as np
 import mindspore as ms
 from mindspore.common.tensor import Tensor
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 from tests.ut.python.model.resnet import resnet50
 
-scala_add = Primitive('scalar_add')
+scala_add = Primitive(Constants.kScalarAdd)
 
 @dataclass

View File

@@ -17,6 +17,7 @@ import numpy as np
 from mindspore import Tensor
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
@@ -26,9 +27,9 @@ from mindspore.ops.operations import _grad_ops as G
 # pylint: disable=unused-argument
 # pylint: disable=redefined-outer-name
 
-scalar_add = Primitive('scalar_add')
-scalar_mul = Primitive('scalar_mul')
-tuple_getitem = Primitive('tuple_getitem')
+scalar_add = Primitive(Constants.kScalarAdd)
+scalar_mul = Primitive(Constants.kScalarMul)
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 switch = Primitive('switch')
@@ -347,7 +348,7 @@ def test_inline_while(tag):
 def test_cse(tag):
     """ test_cse """
     fns = FnDict()
-    scalar_div = Primitive('scalar_div')
+    scalar_div = Primitive(Constants.kScalarDiv)
 
     @fns
     def test_f1(x, y):
@@ -920,9 +921,9 @@ def test_convert_switch_ops(tag):
     fns = FnDict()
     ge_switch = Primitive('GeSwitch')
     merge = Primitive('Merge')
-    add = Primitive('Add')
+    add = Primitive(Constants.kScalarAdd)
     neg = Primitive('Neg')
-    tuple_getitem = Primitive('tuple_getitem')
+    tuple_getitem = Primitive(Constants.kTupleGetItem)
     make_tuple = Primitive('make_tuple')
 
     @fns

View File

@@ -18,8 +18,9 @@ import mindspore.nn as nn
 from mindspore.ops import Primitive
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
-scala_add = Primitive('scalar_add')
+scala_add = Primitive(Constants.kScalarAdd)
 
 @dataclass

View File

@@ -15,6 +15,8 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops import _constants as Constants
+
 Add = P.Add()
 Sub = P.Sub()
@@ -24,7 +26,7 @@ Sqrt = P.Sqrt()
 Square = P.Square()
 Assign = P.Assign()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 AdamApplyOne = Primitive('AdamApplyOne')
 AdamApplyOneAssign = Primitive('AdamApplyOneAssign')

View File

@@ -16,6 +16,7 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops import _constants as Constants
 
 mul = P.Mul()
 add = P.Add()
@@ -25,7 +26,7 @@ real_div = P.RealDiv()
 sub = P.Sub()
 Assign = P.Assign()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
 adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign')

View File

@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 BatchNorm = P.BatchNorm()
 BNTrainingReduce = Primitive('BNTrainingReduce')
 BNTrainingUpdateV2 = Primitive('BNTrainingUpdateV2')

View File

@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import _constants as Constants
 
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 BatchNormGradTraining = G.BatchNormGrad(is_training=True)
 BatchNormGradInfer = G.BatchNormGrad(is_training=False)
 BNInferGrad = Primitive('BNInferGrad')

View File

@@ -15,12 +15,13 @@
 from mindspore.ops import Primitive
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import _constants as Constants
 
 batch_norm_grad = G.BatchNormGrad(is_training=True)
 bn_training_update_grad = Primitive('BNTrainingUpdateGrad')
 bn_training_reduce_grad = Primitive('BNTrainingReduceGrad')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 
 class FnDict:

View File

@@ -15,11 +15,12 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 batch_norm = P.BatchNorm(is_training=False)
 bn_infer = Primitive('BNInfer')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 
 class FnDict:

View File

@@ -15,11 +15,12 @@
 from mindspore.ops import Primitive
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import _constants as Constants
 
 batch_norm_grad = G.BatchNormGrad(is_training=False)
 bn_infer_grad = Primitive('BNInferGrad')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 
 class FnDict:

View File

@@ -15,9 +15,10 @@
 from mindspore.ops import Primitive
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import _constants as Constants
 
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 bn_grad = G.BatchNormGrad(is_training=True)
 bn_grad1 = Primitive('BNGrad1')
 bn_grad2 = Primitive('BNGrad2')

View File

@@ -15,9 +15,10 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 bn = P.BatchNorm(is_training=True)
 fused_bn1 = Primitive('FusedBN1')
 fused_bn2 = Primitive('FusedBN2')

View File

@@ -14,6 +14,7 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 select = P.Select()
 maximum = P.Maximum()
@@ -21,7 +22,7 @@ sqrt = P.Sqrt()
 greater = P.Greater()
 clip_by_norm_no_div_square_sum = Primitive('ClipByNormNoDivSum')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 
 class FnDict:

View File

@@ -14,12 +14,13 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 maximum = P.Maximum()
 minimum = P.Minimum()
 clip_by_value = Primitive('ClipByValue')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 
 class FnDict:

View File

@@ -14,13 +14,14 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 addn = P.AddN()
 mul = P.Mul()
 reduce_sum = P.ReduceSum()
 confusion_mul_grad = Primitive('ConfusionMulGrad')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 axis = 1

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 from mindspore.ops import operations as P
 
 mul = P.Mul()
@@ -20,7 +21,7 @@ reduce_sum = P.ReduceSum(keep_dims=True)
 sub = P.Sub()
 confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 axis = 2

View File

@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 
 make_tuple = Primitive('make_tuple')
-tuple_get_item = Primitive("tuple_getitem")
+tuple_get_item = Primitive(Constants.kTupleGetItem)
 LSTM = P.LSTM(input_size=10, hidden_size=2, num_layers=1, has_bias=True, bidirectional=False, dropout=0.0)
 add = P.Add()

View File

@@ -14,13 +14,14 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 relu = P.ReLU()
 relu_grad = Primitive('ReluGrad')
 relu_v2 = Primitive('ReLUV2')
 relu_grad_v2 = Primitive('ReluGradV2')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -17,12 +17,13 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops import _constants as Constants
 AssignSub = P.AssignSub()
 Mul = P.Mul()
 Sub = P.Sub()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 BatchNorm = P.BatchNorm()
 Cast = P.Cast()
 BNTrainingReduce = Primitive('BNTrainingReduce')


@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -14,8 +14,9 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 depend = P.Depend()
 addn = P.AddN()
 add = P.Add()


@@ -15,12 +15,13 @@
 import mindspore as ms
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 from mindspore.ops import operations as P
 get_next = P.GetNext([ms.float32, ms.int32], [[32, 64], [32]], 2, "")
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -15,12 +15,13 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 all_reduce = P.AllReduce()
 broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 apply_momentun = P.ApplyMomentum()
 control_depend = P.ControlDepend()
 relu = P.ReLU()


@@ -14,8 +14,9 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 add = P.Add()
 max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
 make_tuple = Primitive('make_tuple')


@@ -15,14 +15,15 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import _constants as Constants
 # pylint: disable=unused-variable
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 add = P.Add()
 allreduce = P.AllReduce()
 allreduce.add_prim_attr('fusion', 1)
-make_tuple = Primitive('make_tuple')
+make_tuple = Primitive("make_tuple")
 conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
 bn = P.FusedBatchNorm()
 relu = P.ReLU()


@@ -14,6 +14,7 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Add = P.Add()
 Mul = P.Mul()
@@ -21,7 +22,7 @@ RealDiv = P.RealDiv()
 Rsqrt = P.Rsqrt()
 Sqrt = P.Sqrt()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 LambNextMV = Primitive('LambNextMV')
 class FnDict:


@@ -14,6 +14,7 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Add = P.Add()
 Mul = P.Mul()
@@ -21,7 +22,7 @@ RealDiv = P.RealDiv()
 Rsqrt = P.Rsqrt()
 Sqrt = P.Sqrt()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 LambNextMVWithDecay = Primitive('LambNextMVWithDecay')
 class FnDict:


@@ -14,6 +14,7 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 add = P.Add()
 mul = P.Mul()
@@ -21,7 +22,7 @@ real_div = P.RealDiv()
 rsqrt = P.Rsqrt()
 sqrt = P.Sqrt()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 LambNextMVWithDecayV1 = Primitive('LambNextMVWithDecayV1')


@@ -14,13 +14,14 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Add = P.Add()
 Mul = P.Mul()
 Sqrt = P.Sqrt()
 Square = P.Square()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 LambNextRight = Primitive('LambNextRight')


@@ -14,6 +14,7 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 select = P.Select()
 maximum = P.Maximum()
@@ -24,7 +25,7 @@ mul = P.Mul()
 sub = P.Sub()
 lamb_update_with_lr = Primitive('LambUpdateWithLR')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -14,6 +14,7 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Sub = P.Sub()
 Mul = P.Mul()
@@ -21,7 +22,7 @@ RealDiv = P.RealDiv()
 Select = P.Select()
 Greater = P.Greater()
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 LambUpdateWithLrV2 = Primitive('LambUpdateWithLrV2')


@@ -14,12 +14,13 @@
 # ============================================================================
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 lars_v2 = Primitive('LarsV2')
 square_sum_all = Primitive('SquareSumAll')
 lars_v2_update = Primitive('LarsV2Update')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -14,11 +14,12 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Add = P.Add()
 Cast = P.Cast()
 LayerNormBetaGammaBackprop = Primitive('LayerNormBetaGammaBackprop')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -15,9 +15,10 @@
 from mindspore.ops import Primitive
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import _constants as Constants
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 layer_norm_grad = G.LayerNormGrad()
 layer_norm_x_backprop = Primitive('LayerNormXBackprop')
 layer_norm_beta_gamma_backprop = Primitive('LayerNormBetaGammaBackprop')


@@ -14,8 +14,9 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 depend = P.Depend()
 addn = P.AddN()
 add = P.Add()


@@ -16,11 +16,12 @@ import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Mul = P.Mul()
 ApplyMomentum = P.ApplyMomentum()
 FusedMulApplyMomentum = Primitive('FusedMulApplyMomentum')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 make_tuple = Primitive('make_tuple')
 constant = Tensor(1.0, mstype.float32)


@@ -14,12 +14,13 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 add = P.Add()
 mul = P.Mul()
 fused_mul_add = Primitive('FusedMulAdd')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -16,12 +16,13 @@ import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 addn = P.AddN()
 mul = P.Mul()
 fused_mul_addn = Primitive('FusedMulAddN')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 scalar = Tensor(1.0, mstype.float32)


@@ -15,9 +15,10 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 reduce_min = P.ReduceMin(keep_dims=False)
 reduce_min1 = Primitive('ReduceMin')
 reduce_min2 = Primitive('ReduceMin')


@@ -14,8 +14,9 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 add = P.Add()
 max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
 make_tuple = Primitive('make_tuple')


@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 BatchNorm = P.BatchNorm(is_training=True)
 BNTrainingReduce = Primitive('BNTrainingReduce')
 BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')


@@ -14,13 +14,14 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 Mul = P.Mul()
 ReduceSum = P.ReduceSum(keep_dims=True)
 Sub = P.Sub()
 SoftmaxGradExt = Primitive('SoftmaxGradExt')
 MakeTuple = Primitive('make_tuple')
-TupleGetItem = Primitive('tuple_getitem')
+TupleGetItem = Primitive(Constants.kTupleGetItem)
 axes = (2, 3)


@@ -15,10 +15,11 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 split = P.Split(0, 8)
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 splitv = Primitive('SplitV')


@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 square = P.Square()
 reduce_sum = P.ReduceSum()
 square_sumv1 = Primitive('SquareSumV1')


@@ -14,12 +14,13 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 tensor_scatter_update = P.TensorScatterUpdate()
 tensor_move = Primitive('TensorMove')
 scatter_nd_update = Primitive('ScatterNdUpdate')
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -15,9 +15,10 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 TopK = P.TopK()
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 class FnDict:


@@ -14,8 +14,9 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 add = P.Add()
 max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
 make_tuple = Primitive('make_tuple')


@@ -14,8 +14,9 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 add = P.Add()
 max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
 make_tuple = Primitive('make_tuple')


@@ -14,9 +14,10 @@
 # ============================================================================
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 make_tuple = Primitive('make_tuple')
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 unsorted_segment_sum = P.UnsortedSegmentSum()
 num_segments = 4
 padding = Primitive('Padding')


@@ -15,12 +15,13 @@
 import mindspore.common.dtype as mstype
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 addn = P.AddN()
 add = P.Add()
 reshape = P.Reshape()
 cast = P.Cast()
-tuple_getitem = Primitive('tuple_getitem')
+tuple_getitem = Primitive(Constants.kTupleGetItem)
 max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)


@@ -14,6 +14,7 @@
 # ============================================================================
 """ multi_relu_case """
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
 # Test user define ops
@@ -21,7 +22,7 @@ def get_test_ops_fn():
     return test_ops_f
-scalar_mul = Primitive('scalar_mul')
+scalar_mul = Primitive(Constants.kScalarMul)
 def test_ops_f(x, y):


@@ -14,18 +14,19 @@
 # ============================================================================
 """ vm_test """
 from mindspore.ops import Primitive
+from mindspore.ops import _constants as Constants
-scala_add = Primitive('scalar_add')
-scala_mul = Primitive('scalar_mul')
+scala_add = Primitive(Constants.kScalarAdd)
+scala_mul = Primitive(Constants.kScalarMul)
 scalar_gt = Primitive('scalar_gt')
-def scalar_add(x, y):
+def ScalarAdd(x, y):
     """Implement `scalar_add`."""
     return scala_add(x, y)
-def scalar_mul(x, y):
+def ScalarMul(x, y):
     """Implement `scalar_mul`."""
     return scala_mul(x, y)
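
Besides switching to the shared constants, this vm_test hunk renames the plain-Python wrapper functions from snake_case (scalar_add) to CamelCase (ScalarAdd), so the wrappers no longer collide with the lower-case primitive names they implement. A hedged sketch of the resulting module shape, with the constant values assumed to be the new CamelCase op names:

from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants

scala_add = Primitive(Constants.kScalarAdd)  # graph-level primitive

def ScalarAdd(x, y):
    """Implement `scalar_add` by dispatching to the primitive."""
    return scala_add(x, y)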


@@ -23,6 +23,7 @@
 #include "ir/manager.h"
 #include "pipeline/jit/static_analysis/prim.h"
 #include "frontend/operator/ops.h"
+#include "base/core_ops.h"
 namespace mindspore {
 namespace validator {
@@ -35,7 +36,7 @@ class TestValidator : public UT::Common {
 };
 TEST_F(TestValidator, ValidateOperation01) {
-  auto node = NewValueNode(std::make_shared<Primitive>("scalar_add"));
+  auto node = NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd));
   ValidateOperation(node);
   // normally, the above statement should not exit, so expected the following statement execute
   EXPECT_TRUE(true);


@@ -31,6 +31,7 @@
 #include "utils/convert_utils.h"
 #include "utils/convert_utils_py.h"
 #include "utils/log_adapter.h"
+#include "base/core_ops.h"
 namespace mindspore {
 namespace compile {
@@ -46,7 +47,7 @@ class TestCompileSegmentRunner : public UT::Common {
 };
 TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
-  FuncGraphPtr g = get_py_fun_("scalar_add");
+  FuncGraphPtr g = get_py_fun_(prim::kScalarAdd);
   // g was managed by local variable manager in get_py_fun_ and that manager will be freed as no reference.
   // so a new manager should be declared to make get_outputs() in segment_runner.cc happy.
   std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);
@@ -62,7 +63,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
 }
 TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
-  FuncGraphPtr g = get_py_fun_("scalar_mul");
+  FuncGraphPtr g = get_py_fun_(prim::kScalarMul);
   std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);
   BackendPtr b = std::make_shared<Backend>("vm");


@@ -19,6 +19,7 @@ import mindspore.nn as nn
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import _constants as Constants
 from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
     cancel_new_parameter, set_reopt
 from mindspore.common.api import _generate_pip_args
@@ -296,12 +297,12 @@ def test_imm_target():
         pattern = Call(P.Softmax(), [x])
         imm = Imm(0)
         target_0 = Call("make_tuple", [pattern])
-        target = Call("tuple_getitem", [target_0, imm])
+        target = Call(Constants.kTupleGetItem, [target_0, imm])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     unregiste_pass(softmax_pass)
     assert "make_tuple" in transformed_repr
-    assert "tuple_getitem" in transformed_repr
+    assert Constants.kTupleGetItem in transformed_repr
     assert "Softmax" in transformed_repr
 def test_gen_new_parameter():
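
The assertion change in this hunk is the motivating case for the whole commit: once the op is renamed, printed graphs carry the new name, so string assertions against the old 'tuple_getitem' literal would go stale. Asserting on the constant keeps the test in sync with whatever value _constants defines. A small self-check sketch of that invariant:

from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants

# A Primitive built from the constant reports that same value as its
# name, which is the string that shows up in printed graphs.
tuple_getitem = Primitive(Constants.kTupleGetItem)
assert tuple_getitem.name == Constants.kTupleGetItem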


@@ -18,6 +18,7 @@ import numpy as np
 from mindspore import Tensor
 from mindspore.common.api import ms_function
 from mindspore.ops import Primitive
+from mindspore.ops import _constants
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -28,7 +29,7 @@ from ...ut_filter import non_graph_engine
 tensor_add = P.Add()
-scala_add = Primitive('scalar_add')
+scala_add = Primitive(_constants.kScalarAdd)
 add = C.MultitypeFuncGraph('add')


@@ -21,12 +21,13 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import Primitive
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
+from mindspore.ops import _constants
 from mindspore import dtype as mstype
 from ...ut_filter import non_graph_engine
 tensor_add = P.Add()
 op_add = P.AddN()
-scala_add = Primitive('scalar_add')
+scala_add = Primitive(_constants.kScalarAdd)
 add = C.MultitypeFuncGraph('add')
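
These last two test modules pair the constant-named scalar-add primitive with a MultitypeFuncGraph. For context, a hedged sketch of how such a pairing is typically registered; the "Number" type strings are an assumption, not part of this diff:

from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import _constants

scala_add = Primitive(_constants.kScalarAdd)
add = C.MultitypeFuncGraph('add')

# Assumed registration: dispatch scalar/scalar addition to the primitive.
@add.register("Number", "Number")
def _add_scalars(x, y):
    return scala_add(x, y)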