forked from mindspore-Ecosystem/mindspore
!11982 Change tuple_getitem to TupleGetItem merge from r1.1 to master
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
a616196586
|
@ -21,32 +21,32 @@ import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||||
|
|
||||||
|
|
||||||
def scalar_add(x, y):
|
def ScalarAdd(x, y):
|
||||||
"""Implement `scalar_add`."""
|
"""Implement `scalar_add`."""
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
def scalar_mul(x, y):
|
def ScalarMul(x, y):
|
||||||
"""Implement `scalar_mul`."""
|
"""Implement `scalar_mul`."""
|
||||||
return x * y
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
def scalar_mod(x, y):
|
def ScalarMod(x, y):
|
||||||
"""Implement `scalar_mul`."""
|
"""Implement `scalar_mul`."""
|
||||||
return x % y
|
return x % y
|
||||||
|
|
||||||
|
|
||||||
def scalar_sub(x, y):
|
def ScalarSub(x, y):
|
||||||
"""Implement `scalar_sub`."""
|
"""Implement `scalar_sub`."""
|
||||||
return x - y
|
return x - y
|
||||||
|
|
||||||
|
|
||||||
def scalar_usub(x):
|
def ScalarUsub(x):
|
||||||
"""Implement `scalar_usub`."""
|
"""Implement `scalar_usub`."""
|
||||||
return -x
|
return -x
|
||||||
|
|
||||||
|
|
||||||
def tuple_getitem(x, index):
|
def TupleGetItem(x, index):
|
||||||
"""Implement `tuple_getitem`."""
|
"""Implement `tuple_getitem`."""
|
||||||
if isinstance(x, Tensor):
|
if isinstance(x, Tensor):
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "frontend/operator/prim_to_function.h"
|
#include "frontend/operator/prim_to_function.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// namespace to support prim related definition
|
// namespace to support prim related definition
|
||||||
|
@ -25,31 +26,31 @@ PrimToFunction::PrimToFunction()
|
||||||
{"bool_not", kPrimTypeOneArg},
|
{"bool_not", kPrimTypeOneArg},
|
||||||
{"scalar_cos", kPrimTypeOneArg},
|
{"scalar_cos", kPrimTypeOneArg},
|
||||||
{"scalar_exp", kPrimTypeOneArg},
|
{"scalar_exp", kPrimTypeOneArg},
|
||||||
{"scalar_floor", kPrimTypeOneArg},
|
{kScalarFloor, kPrimTypeOneArg},
|
||||||
{"scalar_log", kPrimTypeOneArg},
|
{"scalar_log", kPrimTypeOneArg},
|
||||||
{"scalar_sin", kPrimTypeOneArg},
|
{"scalar_sin", kPrimTypeOneArg},
|
||||||
{"scalar_tan", kPrimTypeOneArg},
|
{"scalar_tan", kPrimTypeOneArg},
|
||||||
{"scalar_trunc", kPrimTypeOneArg},
|
{kScalarTrunc, kPrimTypeOneArg},
|
||||||
{"typeof", kPrimTypeOneArg},
|
{"typeof", kPrimTypeOneArg},
|
||||||
{"scalar_uadd", kPrimTypeOneArg},
|
{kScalarUadd, kPrimTypeOneArg},
|
||||||
{"scalar_usub", kPrimTypeOneArg},
|
{kScalarUsub, kPrimTypeOneArg},
|
||||||
// TWO_ARGS prim
|
// TWO_ARGS prim
|
||||||
{"scalar_add", kPrimTypeTwoArgs},
|
{kScalarAdd, kPrimTypeTwoArgs},
|
||||||
{"bool_and", kPrimTypeTwoArgs},
|
{"bool_and", kPrimTypeTwoArgs},
|
||||||
{"bool_eq", kPrimTypeTwoArgs},
|
{"bool_eq", kPrimTypeTwoArgs},
|
||||||
{"bool_or", kPrimTypeTwoArgs},
|
{"bool_or", kPrimTypeTwoArgs},
|
||||||
{"scalar_div", kPrimTypeTwoArgs},
|
{kScalarDiv, kPrimTypeTwoArgs},
|
||||||
{"scalar_eq", kPrimTypeTwoArgs},
|
{"scalar_eq", kPrimTypeTwoArgs},
|
||||||
{"scalar_ge", kPrimTypeTwoArgs},
|
{"scalar_ge", kPrimTypeTwoArgs},
|
||||||
{"scalar_gt", kPrimTypeTwoArgs},
|
{"scalar_gt", kPrimTypeTwoArgs},
|
||||||
{"scalar_le", kPrimTypeTwoArgs},
|
{"scalar_le", kPrimTypeTwoArgs},
|
||||||
{"scalar_lt", kPrimTypeTwoArgs},
|
{"scalar_lt", kPrimTypeTwoArgs},
|
||||||
{"scalar_ne", kPrimTypeTwoArgs},
|
{"scalar_ne", kPrimTypeTwoArgs},
|
||||||
{"scalar_mod", kPrimTypeTwoArgs},
|
{kScalarMod, kPrimTypeTwoArgs},
|
||||||
{"scalar_mul", kPrimTypeTwoArgs},
|
{kScalarMul, kPrimTypeTwoArgs},
|
||||||
{"scalar_pow", kPrimTypeTwoArgs},
|
{kScalarPow, kPrimTypeTwoArgs},
|
||||||
{"scalar_sub", kPrimTypeTwoArgs},
|
{kScalarSub, kPrimTypeTwoArgs},
|
||||||
{"scalar_floordiv", kPrimTypeTwoArgs}}) {}
|
{kScalarFloordiv, kPrimTypeTwoArgs}}) {}
|
||||||
|
|
||||||
bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
|
bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
|
||||||
bool result = false;
|
bool result = false;
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
#include "ir/param_info.h"
|
#include "ir/param_info.h"
|
||||||
#include "ir/meta_tensor.h"
|
#include "ir/meta_tensor.h"
|
||||||
#include "pipeline/jit/parse/python_adapter.h"
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
@ -306,7 +307,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
|
||||||
}
|
}
|
||||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||||
if (prim->name() == TUPLE_GETITEM) {
|
if (prim->name() == prim::kTupleGetItem) {
|
||||||
*out_index = GetTupleGetItemIndex(cnode);
|
*out_index = GetTupleGetItemIndex(cnode);
|
||||||
// find tuple_get_item's previous node
|
// find tuple_get_item's previous node
|
||||||
auto pre_node = cnode->input(1);
|
auto pre_node = cnode->input(1);
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_
|
||||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
constexpr size_t PRELU_INPUTS_SIZE = 2;
|
constexpr size_t PRELU_INPUTS_SIZE = 2;
|
||||||
|
@ -320,7 +322,6 @@ constexpr char KStridedSlice[] = "StridedSlice";
|
||||||
constexpr char UNIQUE[] = "Unique";
|
constexpr char UNIQUE[] = "Unique";
|
||||||
|
|
||||||
// Parallel don't care
|
// Parallel don't care
|
||||||
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
|
|
||||||
constexpr char STRING_EQUAL[] = "string_equal";
|
constexpr char STRING_EQUAL[] = "string_equal";
|
||||||
constexpr char MAKE_TUPLE[] = "make_tuple";
|
constexpr char MAKE_TUPLE[] = "make_tuple";
|
||||||
constexpr char MAKE_LIST[] = "make_list";
|
constexpr char MAKE_LIST[] = "make_list";
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ir/value.h"
|
#include "ir/value.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
#include "frontend/parallel/device_matrix.h"
|
#include "frontend/parallel/device_matrix.h"
|
||||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||||
#include "frontend/parallel/strategy.h"
|
#include "frontend/parallel/strategy.h"
|
||||||
|
@ -206,8 +207,8 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)});
|
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)});
|
||||||
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
|
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
|
||||||
auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()});
|
auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()});
|
||||||
auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(0)});
|
auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(prim::kTupleGetItem), unique, CreatInt64Imm(0)});
|
||||||
auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(1)});
|
auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(prim::kTupleGetItem), unique, CreatInt64Imm(1)});
|
||||||
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1});
|
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1});
|
||||||
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
|
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
|
||||||
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast});
|
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast});
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
#include "frontend/optimizer/opt.h"
|
#include "frontend/optimizer/opt.h"
|
||||||
#include "frontend/optimizer/optimizer.h"
|
#include "frontend/optimizer/optimizer.h"
|
||||||
#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
|
#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
|
||||||
|
@ -599,8 +600,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
|
PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
|
||||||
size_t output_index = 0;
|
size_t output_index = 0;
|
||||||
|
|
||||||
bool bool_result =
|
bool bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
|
||||||
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
|
(prev_prim->name() == DEPEND);
|
||||||
while (bool_result) {
|
while (bool_result) {
|
||||||
if (IsAutoParallelCareNode(prev_cnode)) {
|
if (IsAutoParallelCareNode(prev_cnode)) {
|
||||||
auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
|
auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
|
||||||
|
@ -639,7 +640,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
edge_count++;
|
edge_count++;
|
||||||
|
|
||||||
break;
|
break;
|
||||||
} else if (prev_prim->name() == TUPLE_GETITEM) {
|
} else if (prev_prim->name() == prim::kTupleGetItem) {
|
||||||
// In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
|
// In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
|
||||||
// this 'tuple_getitem'
|
// this 'tuple_getitem'
|
||||||
MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
|
MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
|
||||||
|
@ -672,8 +673,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
<< "and creating an edge between the Operator before "
|
<< "and creating an edge between the Operator before "
|
||||||
<< "'depend' and the Operator after 'depend'.";
|
<< "'depend' and the Operator after 'depend'.";
|
||||||
}
|
}
|
||||||
bool_result =
|
bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
|
||||||
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
|
(prev_prim->name() == DEPEND);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
|
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
|
||||||
|
@ -960,13 +961,13 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
|
||||||
|
|
||||||
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
|
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
|
||||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||||
if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
|
if (prim->name() == prim::kTupleGetItem || prim->name() == DEPEND) {
|
||||||
auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
|
auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
|
||||||
if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
|
if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
|
auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
|
||||||
while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
|
while (prev_prim->name() == prim::kTupleGetItem || prev_prim->name() == DEPEND) {
|
||||||
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
|
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
|
||||||
if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
|
if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
#include "frontend/optimizer/optimizer.h"
|
#include "frontend/optimizer/optimizer.h"
|
||||||
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
|
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
|
||||||
|
@ -311,7 +312,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
|
||||||
}
|
}
|
||||||
PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
|
PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
|
||||||
MS_EXCEPTION_IF_NULL(value_node_prim);
|
MS_EXCEPTION_IF_NULL(value_node_prim);
|
||||||
if (value_node_prim->name() == TUPLE_GETITEM) {
|
if (value_node_prim->name() == prim::kTupleGetItem) {
|
||||||
if (uses_set.size() > 1) {
|
if (uses_set.size() > 1) {
|
||||||
MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size();
|
MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size();
|
||||||
}
|
}
|
||||||
|
@ -409,7 +410,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
|
||||||
TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim,
|
TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim,
|
||||||
const OperatorInfoPtr &distribute_operator) {
|
const OperatorInfoPtr &distribute_operator) {
|
||||||
TensorInfo tensorinfo_in;
|
TensorInfo tensorinfo_in;
|
||||||
if (middle_prim->name() == TUPLE_GETITEM) {
|
if (middle_prim->name() == prim::kTupleGetItem) {
|
||||||
auto value_node = middle_node->input(2)->cast<ValueNodePtr>();
|
auto value_node = middle_node->input(2)->cast<ValueNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
size_t index_s = LongToSize(GetValue<int64_t>(value_node->value()));
|
size_t index_s = LongToSize(GetValue<int64_t>(value_node->value()));
|
||||||
|
@ -603,7 +604,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
||||||
MS_EXCEPTION_IF_NULL(current_value);
|
MS_EXCEPTION_IF_NULL(current_value);
|
||||||
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(current_prim);
|
MS_EXCEPTION_IF_NULL(current_prim);
|
||||||
insert_node_new = ((current_prim->name() == TUPLE_GETITEM) ? node : insert_node);
|
insert_node_new = ((current_prim->name() == prim::kTupleGetItem) ? node : insert_node);
|
||||||
} else {
|
} else {
|
||||||
insert_node_new = insert_node;
|
insert_node_new = insert_node;
|
||||||
}
|
}
|
||||||
|
@ -2117,7 +2118,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||||
if (prim->name() == TUPLE_GETITEM) {
|
if (prim->name() == prim::kTupleGetItem) {
|
||||||
auto tuple_index = GetTupleGetItemIndex(cnode);
|
auto tuple_index = GetTupleGetItemIndex(cnode);
|
||||||
auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
|
auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
|
||||||
if (!layout_ptr) {
|
if (!layout_ptr) {
|
||||||
|
@ -2234,7 +2235,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// return -> tuple_getitem -> loss
|
// return -> tuple_getitem -> loss
|
||||||
if (current_prim->name() == TUPLE_GETITEM) {
|
if (current_prim->name() == prim::kTupleGetItem) {
|
||||||
auto tuple_index = GetTupleGetItemIndex(pre_cnode);
|
auto tuple_index = GetTupleGetItemIndex(pre_cnode);
|
||||||
AnfNodePtr pre_pre_node = pre_cnode->input(1);
|
AnfNodePtr pre_pre_node = pre_cnode->input(1);
|
||||||
MS_EXCEPTION_IF_NULL(pre_pre_node);
|
MS_EXCEPTION_IF_NULL(pre_pre_node);
|
||||||
|
@ -2491,7 +2492,7 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
|
||||||
}
|
}
|
||||||
|
|
||||||
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
|
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
|
||||||
if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
|
if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kTupleGetItem)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <stack>
|
#include <stack>
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "ir/graph_utils.h"
|
#include "ir/graph_utils.h"
|
||||||
|
@ -631,7 +632,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
|
||||||
MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
|
MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
|
||||||
}
|
}
|
||||||
TraceOutput(c->input(1));
|
TraceOutput(c->input(1));
|
||||||
} else if (name == "tuple_getitem") {
|
} else if (name == prim::kTupleGetItem) {
|
||||||
TraceOutputFromTupleGetItem(anf_out);
|
TraceOutputFromTupleGetItem(anf_out);
|
||||||
} else {
|
} else {
|
||||||
// add outputs;
|
// add outputs;
|
||||||
|
@ -1014,7 +1015,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
||||||
if (it != out_handle_cache_.end()) {
|
if (it != out_handle_cache_.end()) {
|
||||||
int ret = adpt->setInput(src, SizeToInt(i), it->second);
|
int ret = adpt->setInput(src, SizeToInt(i), it->second);
|
||||||
if (ret == 0) {
|
if (ret == 0) {
|
||||||
if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
|
if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kTupleGetItem) {
|
||||||
compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
|
compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
|
||||||
<< ":" << i << endl;
|
<< ":" << i << endl;
|
||||||
} else if (pred->isa<Parameter>()) {
|
} else if (pred->isa<Parameter>()) {
|
||||||
|
@ -1538,7 +1539,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
||||||
}
|
}
|
||||||
|
|
||||||
// As for nodes with multi outputs, convert tuple_getitem to OutHandle
|
// As for nodes with multi outputs, convert tuple_getitem to OutHandle
|
||||||
if (name == "tuple_getitem") {
|
if (name == prim::kTupleGetItem) {
|
||||||
ConvertTupleGetItem(node);
|
ConvertTupleGetItem(node);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,19 +26,33 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace prim {
|
namespace prim {
|
||||||
constexpr auto kGather = "Gather";
|
constexpr auto kGather = "Gather";
|
||||||
|
// Arithmetic
|
||||||
|
constexpr auto kScalarAdd = "ScalarAdd";
|
||||||
|
constexpr auto kScalarSub = "ScalarSub";
|
||||||
|
constexpr auto kScalarMul = "ScalarMul";
|
||||||
|
constexpr auto kScalarDiv = "ScalarDiv";
|
||||||
|
constexpr auto kScalarFloordiv = "ScalarFloordiv";
|
||||||
|
constexpr auto kScalarMod = "ScalarMod";
|
||||||
|
constexpr auto kScalarPow = "ScalarPow";
|
||||||
|
constexpr auto kScalarTrunc = "ScalarTrunc";
|
||||||
|
constexpr auto kScalarFloor = "ScalarFloor";
|
||||||
|
constexpr auto kScalarUadd = "ScalarUadd";
|
||||||
|
constexpr auto kScalarUsub = "ScalarUsub";
|
||||||
|
constexpr auto kTupleGetItem = "TupleGetItem";
|
||||||
|
|
||||||
// Here list all primitives used in backend or some special primitives used by core.
|
// Here list all primitives used in backend or some special primitives used by core.
|
||||||
// Arithmetic
|
// Arithmetic
|
||||||
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
|
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>(kScalarAdd);
|
||||||
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
|
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>(kScalarSub);
|
||||||
inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
|
inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>(kScalarMul);
|
||||||
inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
|
inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>(kScalarDiv);
|
||||||
inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
|
inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>(kScalarFloordiv);
|
||||||
inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
|
inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>(kScalarMod);
|
||||||
inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
|
inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>(kScalarPow);
|
||||||
inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
|
inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>(kScalarTrunc);
|
||||||
inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
|
inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>(kScalarFloor);
|
||||||
inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
|
inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>(kScalarUadd);
|
||||||
inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
|
inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>(kScalarUsub);
|
||||||
inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
|
inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
|
||||||
inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
|
inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
|
||||||
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
||||||
|
@ -295,7 +309,7 @@ inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
|
||||||
|
|
||||||
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
|
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
|
||||||
inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
|
inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
|
||||||
inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
|
inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>(kTupleGetItem);
|
||||||
inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
|
inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
|
||||||
inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
|
inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
|
||||||
inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
|
inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
|
||||||
|
|
|
@ -19,9 +19,11 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {"tuple_getitem", "J", "list_getitem",
|
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||||
"make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem",
|
"make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem",
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""bprop primitives"""
|
"""bprop primitives"""
|
||||||
|
from mindspore.ops import _constants
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
|
@ -50,31 +51,31 @@ def bprop_relu_grad_grad(x, y, out, dout):
|
||||||
return dy, F.zeros_like(y)
|
return dy, F.zeros_like(y)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_add")
|
@bprops.register(_constants.kScalarAdd)
|
||||||
def bprop_scalar_add(x, y, out, dout):
|
def bprop_scalar_add(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_add`."""
|
"""Backpropagator for primitive `scalar_add`."""
|
||||||
return dout, dout
|
return dout, dout
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_mul")
|
@bprops.register(_constants.kScalarMul)
|
||||||
def bprop_scalar_mul(x, y, out, dout):
|
def bprop_scalar_mul(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_mul`."""
|
"""Backpropagator for primitive `scalar_mul`."""
|
||||||
return dout*y, dout*x
|
return dout*y, dout*x
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_sub")
|
@bprops.register(_constants.kScalarSub)
|
||||||
def bprop_scalar_sub(x, y, out, dout):
|
def bprop_scalar_sub(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_sub`."""
|
"""Backpropagator for primitive `scalar_sub`."""
|
||||||
return dout, -dout
|
return dout, -dout
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_div")
|
@bprops.register(_constants.kScalarDiv)
|
||||||
def bprop_scalar_div(x, y, out, dout):
|
def bprop_scalar_div(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_div`."""
|
"""Backpropagator for primitive `scalar_div`."""
|
||||||
return dout/y, (-dout) * (out/y)
|
return dout/y, (-dout) * (out/y)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_pow")
|
@bprops.register(_constants.kScalarPow)
|
||||||
def bprop_scalar_pow(x, y, out, dout):
|
def bprop_scalar_pow(x, y, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_pow`."""
|
"""Backpropagator for primitive `scalar_pow`."""
|
||||||
return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)
|
return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)
|
||||||
|
@ -86,13 +87,13 @@ def bprop_scalar_exp(x, out, dout):
|
||||||
return (dout * out,)
|
return (dout * out,)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_uadd")
|
@bprops.register(_constants.kScalarUadd)
|
||||||
def bprop_scalar_uadd(x, out, dout):
|
def bprop_scalar_uadd(x, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_uadd`."""
|
"""Backpropagator for primitive `scalar_uadd`."""
|
||||||
return (dout,)
|
return (dout,)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("scalar_usub")
|
@bprops.register(_constants.kScalarUsub)
|
||||||
def bprop_scalar_usub(x, out, dout):
|
def bprop_scalar_usub(x, out, dout):
|
||||||
"""Backpropagator for primitive `scalar_usub`."""
|
"""Backpropagator for primitive `scalar_usub`."""
|
||||||
return (-dout,)
|
return (-dout,)
|
||||||
|
@ -140,7 +141,7 @@ def bprop_scalar_cast(x, t, out, dout):
|
||||||
return F.scalar_cast(dout, F.typeof(x)), t
|
return F.scalar_cast(dout, F.typeof(x)), t
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("tuple_getitem")
|
@bprops.register(_constants.kTupleGetItem)
|
||||||
def bprop_tuple_getitem(data, idx, out, dout):
|
def bprop_tuple_getitem(data, idx, out, dout):
|
||||||
"""Backpropagator for primitive `tuple_getitem`."""
|
"""Backpropagator for primitive `tuple_getitem`."""
|
||||||
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
||||||
|
|
|
@ -18,11 +18,11 @@
|
||||||
"""The names of functional part are summarized here."""
|
"""The names of functional part are summarized here."""
|
||||||
|
|
||||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||||
|
from mindspore.ops import _constants
|
||||||
from .primitive import Primitive
|
from .primitive import Primitive
|
||||||
from . import operations as P
|
from . import operations as P
|
||||||
from .operations import _grad_ops
|
from .operations import _grad_ops
|
||||||
|
|
||||||
|
|
||||||
typeof = Primitive('typeof')
|
typeof = Primitive('typeof')
|
||||||
hastype = Primitive('hastype')
|
hastype = Primitive('hastype')
|
||||||
cast = P.Cast()
|
cast = P.Cast()
|
||||||
|
@ -96,7 +96,7 @@ depend = P.Depend()
|
||||||
identity = P.identity()
|
identity = P.identity()
|
||||||
|
|
||||||
tuple_setitem = Primitive('tuple_setitem')
|
tuple_setitem = Primitive('tuple_setitem')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(_constants.kTupleGetItem)
|
||||||
list_getitem = Primitive('list_getitem')
|
list_getitem = Primitive('list_getitem')
|
||||||
list_setitem = Primitive('list_setitem')
|
list_setitem = Primitive('list_setitem')
|
||||||
dict_getitem = Primitive('dict_getitem')
|
dict_getitem = Primitive('dict_getitem')
|
||||||
|
@ -114,22 +114,22 @@ tuple_equal = Primitive("tuple_equal")
|
||||||
list_equal = Primitive("list_equal")
|
list_equal = Primitive("list_equal")
|
||||||
make_ref = Primitive("make_ref")
|
make_ref = Primitive("make_ref")
|
||||||
|
|
||||||
scalar_add = Primitive('scalar_add')
|
scalar_add = Primitive(_constants.kScalarAdd)
|
||||||
scalar_mul = Primitive('scalar_mul')
|
scalar_mul = Primitive(_constants.kScalarMul)
|
||||||
scalar_sub = Primitive('scalar_sub')
|
scalar_sub = Primitive(_constants.kScalarSub)
|
||||||
scalar_div = Primitive('scalar_div')
|
scalar_div = Primitive(_constants.kScalarDiv)
|
||||||
scalar_floordiv = Primitive('scalar_floordiv')
|
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
|
||||||
scalar_log = Primitive('scalar_log')
|
scalar_log = Primitive('scalar_log')
|
||||||
scalar_pow = Primitive('scalar_pow')
|
scalar_pow = Primitive(_constants.kScalarPow)
|
||||||
scalar_gt = Primitive('scalar_gt')
|
scalar_gt = Primitive('scalar_gt')
|
||||||
scalar_ge = Primitive('scalar_ge')
|
scalar_ge = Primitive('scalar_ge')
|
||||||
scalar_le = Primitive('scalar_le')
|
scalar_le = Primitive('scalar_le')
|
||||||
scalar_lt = Primitive('scalar_lt')
|
scalar_lt = Primitive('scalar_lt')
|
||||||
scalar_eq = Primitive('scalar_eq')
|
scalar_eq = Primitive('scalar_eq')
|
||||||
scalar_ne = Primitive('scalar_ne')
|
scalar_ne = Primitive('scalar_ne')
|
||||||
scalar_uadd = Primitive('scalar_uadd')
|
scalar_uadd = Primitive(_constants.kScalarUadd)
|
||||||
scalar_usub = Primitive('scalar_usub')
|
scalar_usub = Primitive(_constants.kScalarUsub)
|
||||||
scalar_mod = Primitive('scalar_mod')
|
scalar_mod = Primitive(_constants.kScalarMod)
|
||||||
string_eq = Primitive('string_equal')
|
string_eq = Primitive('string_equal')
|
||||||
string_concat = Primitive('string_concat')
|
string_concat = Primitive('string_concat')
|
||||||
bool_not = Primitive("bool_not")
|
bool_not = Primitive("bool_not")
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
|
@ -32,7 +33,7 @@ class TestAnf : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestAnf, test_ValueNode) {
|
TEST_F(TestAnf, test_ValueNode) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_add");
|
auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
ValueNodePtr c = NewValueNode(prim);
|
ValueNodePtr c = NewValueNode(prim);
|
||||||
ASSERT_EQ(c->isa<ValueNode>(), true);
|
ASSERT_EQ(c->isa<ValueNode>(), true);
|
||||||
ASSERT_EQ(IsValueNode<Primitive>(c), true);
|
ASSERT_EQ(IsValueNode<Primitive>(c), true);
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "pipeline/jit/parse/parse.h"
|
#include "pipeline/jit/parse/parse.h"
|
||||||
#include "ir/graph_utils.h"
|
#include "ir/graph_utils.h"
|
||||||
#include "debug/draw.h"
|
#include "debug/draw.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class TestCloner : public UT::Common {
|
class TestCloner : public UT::Common {
|
||||||
|
@ -89,7 +90,7 @@ TEST_F(TestCloner, test_clone_simple) {
|
||||||
Cloner cl2(gs);
|
Cloner cl2(gs);
|
||||||
auto g3 = cl2[g];
|
auto g3 = cl2[g];
|
||||||
|
|
||||||
std::vector<Primitive> results = {Primitive("scalar_add"), Primitive("scalar_mul"), Primitive("return")};
|
std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("return")};
|
||||||
AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
|
AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
|
||||||
common = d1 & d3;
|
common = d1 & d3;
|
||||||
for (auto& x : common) {
|
for (auto& x : common) {
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "pybind_api/ir/primitive_py.h"
|
#include "pybind_api/ir/primitive_py.h"
|
||||||
#include "pipeline/jit/parse/python_adapter.h"
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace prim {
|
namespace prim {
|
||||||
|
@ -34,52 +35,52 @@ class TestOps : public UT::Common {
|
||||||
|
|
||||||
// Arithmetic
|
// Arithmetic
|
||||||
TEST_F(TestOps, ScalarAddTest) {
|
TEST_F(TestOps, ScalarAddTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_add");
|
auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarAdd->name());
|
ASSERT_EQ(prim->name(), kPrimScalarAdd->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarSubTest) {
|
TEST_F(TestOps, ScalarSubTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_sub");
|
auto prim = std::make_shared<Primitive>(prim::kScalarSub);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarSub->name());
|
ASSERT_EQ(prim->name(), kPrimScalarSub->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarMulTest) {
|
TEST_F(TestOps, ScalarMulTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_mul");
|
auto prim = std::make_shared<Primitive>(prim::kScalarMul);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarMul->name());
|
ASSERT_EQ(prim->name(), kPrimScalarMul->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarDivTest) {
|
TEST_F(TestOps, ScalarDivTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_div");
|
auto prim = std::make_shared<Primitive>(prim::kScalarDiv);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarDiv->name());
|
ASSERT_EQ(prim->name(), kPrimScalarDiv->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarModTest) {
|
TEST_F(TestOps, ScalarModTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_mod");
|
auto prim = std::make_shared<Primitive>(prim::kScalarMod);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarMod->name());
|
ASSERT_EQ(prim->name(), kPrimScalarMod->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarPowTest) {
|
TEST_F(TestOps, ScalarPowTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_pow");
|
auto prim = std::make_shared<Primitive>(prim::kScalarPow);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarPow->name());
|
ASSERT_EQ(prim->name(), kPrimScalarPow->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarTruncTest) {
|
TEST_F(TestOps, ScalarTruncTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_trunc");
|
auto prim = std::make_shared<Primitive>(prim::kScalarTrunc);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarTrunc->name());
|
ASSERT_EQ(prim->name(), kPrimScalarTrunc->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarFloorTest) {
|
TEST_F(TestOps, ScalarFloorTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_floor");
|
auto prim = std::make_shared<Primitive>(prim::kScalarFloor);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarFloor->name());
|
ASSERT_EQ(prim->name(), kPrimScalarFloor->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarUaddTest) {
|
TEST_F(TestOps, ScalarUaddTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_uadd");
|
auto prim = std::make_shared<Primitive>(prim::kScalarUadd);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarUadd->name());
|
ASSERT_EQ(prim->name(), kPrimScalarUadd->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, ScalarUsubTest) {
|
TEST_F(TestOps, ScalarUsubTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_usub");
|
auto prim = std::make_shared<Primitive>(prim::kScalarUsub);
|
||||||
ASSERT_EQ(prim->name(), kPrimScalarUsub->name());
|
ASSERT_EQ(prim->name(), kPrimScalarUsub->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,7 +188,7 @@ TEST_F(TestOps, MakeRecordTest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, TupleGetItemTest) {
|
TEST_F(TestOps, TupleGetItemTest) {
|
||||||
auto prim = std::make_shared<Primitive>("tuple_getitem");
|
auto prim = std::make_shared<Primitive>(kTupleGetItem);
|
||||||
ASSERT_EQ(prim->name(), kPrimTupleGetItem->name());
|
ASSERT_EQ(prim->name(), kPrimTupleGetItem->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "frontend/operator/prim_to_function.h"
|
#include "frontend/operator/prim_to_function.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace prim {
|
namespace prim {
|
||||||
|
@ -33,7 +34,7 @@ class TestPrimFunc : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestPrimFunc, ScalarAddTest) {
|
TEST_F(TestPrimFunc, ScalarAddTest) {
|
||||||
auto prim = std::make_shared<Primitive>("scalar_add");
|
auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
FunctionPtr func = nullptr;
|
FunctionPtr func = nullptr;
|
||||||
PrimToFunction::GetInstance().GetFunction(prim, &func);
|
PrimToFunction::GetInstance().GetFunction(prim, &func);
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "debug/draw.h"
|
#include "debug/draw.h"
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
#include "utils/symbolic.h"
|
#include "utils/symbolic.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
@ -154,7 +155,7 @@ TEST_F(TestPrim, test_list_map) {
|
||||||
AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), false);
|
AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), false);
|
||||||
AbstractBasePtr abstract_u2 = FromValue(static_cast<int64_t>(2), false);
|
AbstractBasePtr abstract_u2 = FromValue(static_cast<int64_t>(2), false);
|
||||||
auto abstract_list2 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v2, abstract_u2}));
|
auto abstract_list2 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v2, abstract_u2}));
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
|
AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
|
||||||
|
|
||||||
args_spec_list.push_back(abstract_func);
|
args_spec_list.push_back(abstract_func);
|
||||||
|
@ -179,7 +180,7 @@ TEST_F(TestPrim, test_list_reduce) {
|
||||||
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
||||||
AbstractBasePtr abstract_v2 = FromValue(v1, false);
|
AbstractBasePtr abstract_v2 = FromValue(v1, false);
|
||||||
auto abstract_list = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
|
auto abstract_list = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
|
AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
|
||||||
|
|
||||||
args_spec_list.push_back(abstract_func);
|
args_spec_list.push_back(abstract_func);
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "ir/graph_utils.h"
|
#include "ir/graph_utils.h"
|
||||||
#include "utils/misc.h"
|
#include "utils/misc.h"
|
||||||
#include "debug/draw.h"
|
#include "debug/draw.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
@ -95,7 +96,7 @@ void TestSpecializeGraph::SetUp() {
|
||||||
// build func_graph beta
|
// build func_graph beta
|
||||||
ParameterPtr x1 = graph_beta_->add_parameter();
|
ParameterPtr x1 = graph_beta_->add_parameter();
|
||||||
inputs.clear();
|
inputs.clear();
|
||||||
inputs.push_back(NewValueNode(std::make_shared<Primitive>("scalar_add")));
|
inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd)));
|
||||||
inputs.push_back(x1);
|
inputs.push_back(x1);
|
||||||
inputs.push_back(y);
|
inputs.push_back(y);
|
||||||
CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
|
CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
|
||||||
|
@ -166,7 +167,7 @@ class MetaScalarAdd : public MetaFuncGraph {
|
||||||
FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
|
FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
|
||||||
ParameterPtr x = graph_g->add_parameter();
|
ParameterPtr x = graph_g->add_parameter();
|
||||||
ParameterPtr y = graph_g->add_parameter();
|
ParameterPtr y = graph_g->add_parameter();
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
std::vector<AnfNodePtr> inputs;
|
std::vector<AnfNodePtr> inputs;
|
||||||
inputs.push_back(NewValueNode(prim_scalar_add));
|
inputs.push_back(NewValueNode(prim_scalar_add));
|
||||||
inputs.push_back(x);
|
inputs.push_back(x);
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
#include "pipeline/jit/resource.h"
|
#include "pipeline/jit/resource.h"
|
||||||
#include "debug/draw.h"
|
#include "debug/draw.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
@ -96,7 +97,7 @@ class MetaScalarAdd : public MetaFuncGraph {
|
||||||
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
||||||
ParameterPtr x = fg->add_parameter();
|
ParameterPtr x = fg->add_parameter();
|
||||||
ParameterPtr y = fg->add_parameter();
|
ParameterPtr y = fg->add_parameter();
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
std::vector<AnfNodePtr> inputs;
|
std::vector<AnfNodePtr> inputs;
|
||||||
inputs.push_back(NewValueNode(prim_scalar_add));
|
inputs.push_back(NewValueNode(prim_scalar_add));
|
||||||
inputs.push_back(x);
|
inputs.push_back(x);
|
||||||
|
@ -161,7 +162,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
args_spec_list.push_back(abstract_v2);
|
args_spec_list.push_back(abstract_v2);
|
||||||
|
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
||||||
|
@ -388,7 +389,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
|
||||||
args_spec.push_back(abstract_v1);
|
args_spec.push_back(abstract_v1);
|
||||||
args_spec.push_back(abstract_v2);
|
args_spec.push_back(abstract_v2);
|
||||||
|
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
|
||||||
|
@ -418,7 +419,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
|
||||||
AbstractBasePtr base1 = FromValue(x1, false);
|
AbstractBasePtr base1 = FromValue(x1, false);
|
||||||
AbstractBasePtr base2 = FromValue(x2, false);
|
AbstractBasePtr base2 = FromValue(x2, false);
|
||||||
AbstractBasePtrList base_list = {base1, base2};
|
AbstractBasePtrList base_list = {base1, base2};
|
||||||
auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract();
|
auto res = EvalOnePrim(std::make_shared<Primitive>(prim::kScalarAdd), base_list)->abstract();
|
||||||
MS_LOG(INFO) << "result spec: " << res->ToString();
|
MS_LOG(INFO) << "result spec: " << res->ToString();
|
||||||
AbstractBasePtr exp = FromValue(x3, false);
|
AbstractBasePtr exp = FromValue(x3, false);
|
||||||
MS_LOG(INFO) << "result exp: " << exp->ToString();
|
MS_LOG(INFO) << "result exp: " << exp->ToString();
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" Test for GraphCloner """
|
""" Test for GraphCloner """
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
scala_add = Primitive('scalar_add')
|
scala_add = Primitive(Constants.kScalarAdd)
|
||||||
scalar_mul = Primitive('scalar_mul')
|
scalar_mul = Primitive(Constants.kScalarMul)
|
||||||
|
|
||||||
|
|
||||||
def test_clone_simple():
|
def test_clone_simple():
|
||||||
|
|
|
@ -18,9 +18,11 @@ import numpy as np
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
from tests.ut.python.model.resnet import resnet50
|
from tests.ut.python.model.resnet import resnet50
|
||||||
|
|
||||||
scala_add = Primitive('scalar_add')
|
|
||||||
|
scala_add = Primitive(Constants.kScalarAdd)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
|
||||||
|
@ -26,9 +27,9 @@ from mindspore.ops.operations import _grad_ops as G
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
scalar_add = Primitive('scalar_add')
|
scalar_add = Primitive(Constants.kScalarAdd)
|
||||||
scalar_mul = Primitive('scalar_mul')
|
scalar_mul = Primitive(Constants.kScalarMul)
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
switch = Primitive('switch')
|
switch = Primitive('switch')
|
||||||
|
|
||||||
|
|
||||||
|
@ -347,7 +348,7 @@ def test_inline_while(tag):
|
||||||
def test_cse(tag):
|
def test_cse(tag):
|
||||||
""" test_cse """
|
""" test_cse """
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
scalar_div = Primitive('scalar_div')
|
scalar_div = Primitive(Constants.kScalarDiv)
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def test_f1(x, y):
|
def test_f1(x, y):
|
||||||
|
@ -920,9 +921,9 @@ def test_convert_switch_ops(tag):
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
ge_switch = Primitive('GeSwitch')
|
ge_switch = Primitive('GeSwitch')
|
||||||
merge = Primitive('Merge')
|
merge = Primitive('Merge')
|
||||||
add = Primitive('Add')
|
add = Primitive(Constants.kScalarAdd)
|
||||||
neg = Primitive('Neg')
|
neg = Primitive('Neg')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
|
|
|
@ -18,8 +18,9 @@ import mindspore.nn as nn
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
scala_add = Primitive('scalar_add')
|
scala_add = Primitive(Constants.kScalarAdd)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
|
|
||||||
Add = P.Add()
|
Add = P.Add()
|
||||||
Sub = P.Sub()
|
Sub = P.Sub()
|
||||||
|
@ -24,7 +26,7 @@ Sqrt = P.Sqrt()
|
||||||
Square = P.Square()
|
Square = P.Square()
|
||||||
Assign = P.Assign()
|
Assign = P.Assign()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
AdamApplyOne = Primitive('AdamApplyOne')
|
AdamApplyOne = Primitive('AdamApplyOne')
|
||||||
AdamApplyOneAssign = Primitive('AdamApplyOneAssign')
|
AdamApplyOneAssign = Primitive('AdamApplyOneAssign')
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
|
@ -25,7 +26,7 @@ real_div = P.RealDiv()
|
||||||
sub = P.Sub()
|
sub = P.Sub()
|
||||||
Assign = P.Assign()
|
Assign = P.Assign()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
|
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
|
||||||
adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign')
|
adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign')
|
||||||
|
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
BatchNorm = P.BatchNorm()
|
BatchNorm = P.BatchNorm()
|
||||||
BNTrainingReduce = Primitive('BNTrainingReduce')
|
BNTrainingReduce = Primitive('BNTrainingReduce')
|
||||||
BNTrainingUpdateV2 = Primitive('BNTrainingUpdateV2')
|
BNTrainingUpdateV2 = Primitive('BNTrainingUpdateV2')
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
BatchNormGradTraining = G.BatchNormGrad(is_training=True)
|
BatchNormGradTraining = G.BatchNormGrad(is_training=True)
|
||||||
BatchNormGradInfer = G.BatchNormGrad(is_training=False)
|
BatchNormGradInfer = G.BatchNormGrad(is_training=False)
|
||||||
BNInferGrad = Primitive('BNInferGrad')
|
BNInferGrad = Primitive('BNInferGrad')
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
batch_norm_grad = G.BatchNormGrad(is_training=True)
|
batch_norm_grad = G.BatchNormGrad(is_training=True)
|
||||||
bn_training_update_grad = Primitive('BNTrainingUpdateGrad')
|
bn_training_update_grad = Primitive('BNTrainingUpdateGrad')
|
||||||
bn_training_reduce_grad = Primitive('BNTrainingReduceGrad')
|
bn_training_reduce_grad = Primitive('BNTrainingReduceGrad')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -15,11 +15,12 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
batch_norm = P.BatchNorm(is_training=False)
|
batch_norm = P.BatchNorm(is_training=False)
|
||||||
bn_infer = Primitive('BNInfer')
|
bn_infer = Primitive('BNInfer')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -15,11 +15,12 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
batch_norm_grad = G.BatchNormGrad(is_training=False)
|
batch_norm_grad = G.BatchNormGrad(is_training=False)
|
||||||
bn_infer_grad = Primitive('BNInferGrad')
|
bn_infer_grad = Primitive('BNInferGrad')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
bn_grad = G.BatchNormGrad(is_training=True)
|
bn_grad = G.BatchNormGrad(is_training=True)
|
||||||
bn_grad1 = Primitive('BNGrad1')
|
bn_grad1 = Primitive('BNGrad1')
|
||||||
bn_grad2 = Primitive('BNGrad2')
|
bn_grad2 = Primitive('BNGrad2')
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
bn = P.BatchNorm(is_training=True)
|
bn = P.BatchNorm(is_training=True)
|
||||||
fused_bn1 = Primitive('FusedBN1')
|
fused_bn1 = Primitive('FusedBN1')
|
||||||
fused_bn2 = Primitive('FusedBN2')
|
fused_bn2 = Primitive('FusedBN2')
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
select = P.Select()
|
select = P.Select()
|
||||||
maximum = P.Maximum()
|
maximum = P.Maximum()
|
||||||
|
@ -21,7 +22,7 @@ sqrt = P.Sqrt()
|
||||||
greater = P.Greater()
|
greater = P.Greater()
|
||||||
clip_by_norm_no_div_square_sum = Primitive('ClipByNormNoDivSum')
|
clip_by_norm_no_div_square_sum = Primitive('ClipByNormNoDivSum')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,12 +14,13 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
maximum = P.Maximum()
|
maximum = P.Maximum()
|
||||||
minimum = P.Minimum()
|
minimum = P.Minimum()
|
||||||
clip_by_value = Primitive('ClipByValue')
|
clip_by_value = Primitive('ClipByValue')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,13 +14,14 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
addn = P.AddN()
|
addn = P.AddN()
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
reduce_sum = P.ReduceSum()
|
reduce_sum = P.ReduceSum()
|
||||||
confusion_mul_grad = Primitive('ConfusionMulGrad')
|
confusion_mul_grad = Primitive('ConfusionMulGrad')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
axis = 1
|
axis = 1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
|
@ -20,7 +21,7 @@ reduce_sum = P.ReduceSum(keep_dims=True)
|
||||||
sub = P.Sub()
|
sub = P.Sub()
|
||||||
confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
|
confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
axis = 2
|
axis = 2
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_get_item = Primitive("tuple_getitem")
|
tuple_get_item = Primitive(Constants.kTupleGetItem)
|
||||||
LSTM = P.LSTM(input_size=10, hidden_size=2, num_layers=1, has_bias=True, bidirectional=False, dropout=0.0)
|
LSTM = P.LSTM(input_size=10, hidden_size=2, num_layers=1, has_bias=True, bidirectional=False, dropout=0.0)
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
|
|
||||||
|
|
|
@ -14,13 +14,14 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
relu = P.ReLU()
|
relu = P.ReLU()
|
||||||
relu_grad = Primitive('ReluGrad')
|
relu_grad = Primitive('ReluGrad')
|
||||||
relu_v2 = Primitive('ReLUV2')
|
relu_v2 = Primitive('ReLUV2')
|
||||||
relu_grad_v2 = Primitive('ReluGradV2')
|
relu_grad_v2 = Primitive('ReluGradV2')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -17,12 +17,13 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
AssignSub = P.AssignSub()
|
AssignSub = P.AssignSub()
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
Sub = P.Sub()
|
Sub = P.Sub()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
BatchNorm = P.BatchNorm()
|
BatchNorm = P.BatchNorm()
|
||||||
Cast = P.Cast()
|
Cast = P.Cast()
|
||||||
BNTrainingReduce = Primitive('BNTrainingReduce')
|
BNTrainingReduce = Primitive('BNTrainingReduce')
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
depend = P.Depend()
|
depend = P.Depend()
|
||||||
addn = P.AddN()
|
addn = P.AddN()
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
get_next = P.GetNext([ms.float32, ms.int32], [[32, 64], [32]], 2, "")
|
get_next = P.GetNext([ms.float32, ms.int32], [[32, 64], [32]], 2, "")
|
||||||
memcpy_async = Primitive('memcpy_async')
|
memcpy_async = Primitive('memcpy_async')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
all_reduce = P.AllReduce()
|
all_reduce = P.AllReduce()
|
||||||
broadcast = P.Broadcast(1)
|
broadcast = P.Broadcast(1)
|
||||||
memcpy_async = Primitive('memcpy_async')
|
memcpy_async = Primitive('memcpy_async')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
apply_momentun = P.ApplyMomentum()
|
apply_momentun = P.ApplyMomentum()
|
||||||
control_depend = P.ControlDepend()
|
control_depend = P.ControlDepend()
|
||||||
relu = P.ReLU()
|
relu = P.ReLU()
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
|
|
|
@ -15,14 +15,15 @@
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
# pylint: disable=unused-variable
|
# pylint: disable=unused-variable
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
allreduce = P.AllReduce()
|
allreduce = P.AllReduce()
|
||||||
allreduce.add_prim_attr('fusion', 1)
|
allreduce.add_prim_attr('fusion', 1)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive("make_tuple")
|
||||||
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
|
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
|
||||||
bn = P.FusedBatchNorm()
|
bn = P.FusedBatchNorm()
|
||||||
relu = P.ReLU()
|
relu = P.ReLU()
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Add = P.Add()
|
Add = P.Add()
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
|
@ -21,7 +22,7 @@ RealDiv = P.RealDiv()
|
||||||
Rsqrt = P.Rsqrt()
|
Rsqrt = P.Rsqrt()
|
||||||
Sqrt = P.Sqrt()
|
Sqrt = P.Sqrt()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
LambNextMV = Primitive('LambNextMV')
|
LambNextMV = Primitive('LambNextMV')
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Add = P.Add()
|
Add = P.Add()
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
|
@ -21,7 +22,7 @@ RealDiv = P.RealDiv()
|
||||||
Rsqrt = P.Rsqrt()
|
Rsqrt = P.Rsqrt()
|
||||||
Sqrt = P.Sqrt()
|
Sqrt = P.Sqrt()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
LambNextMVWithDecay = Primitive('LambNextMVWithDecay')
|
LambNextMVWithDecay = Primitive('LambNextMVWithDecay')
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
|
@ -21,7 +22,7 @@ real_div = P.RealDiv()
|
||||||
rsqrt = P.Rsqrt()
|
rsqrt = P.Rsqrt()
|
||||||
sqrt = P.Sqrt()
|
sqrt = P.Sqrt()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
LambNextMVWithDecayV1 = Primitive('LambNextMVWithDecayV1')
|
LambNextMVWithDecayV1 = Primitive('LambNextMVWithDecayV1')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,13 +14,14 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Add = P.Add()
|
Add = P.Add()
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
Sqrt = P.Sqrt()
|
Sqrt = P.Sqrt()
|
||||||
Square = P.Square()
|
Square = P.Square()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
LambNextRight = Primitive('LambNextRight')
|
LambNextRight = Primitive('LambNextRight')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
select = P.Select()
|
select = P.Select()
|
||||||
maximum = P.Maximum()
|
maximum = P.Maximum()
|
||||||
|
@ -24,7 +25,7 @@ mul = P.Mul()
|
||||||
sub = P.Sub()
|
sub = P.Sub()
|
||||||
lamb_update_with_lr = Primitive('LambUpdateWithLR')
|
lamb_update_with_lr = Primitive('LambUpdateWithLR')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Sub = P.Sub()
|
Sub = P.Sub()
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
|
@ -21,7 +22,7 @@ RealDiv = P.RealDiv()
|
||||||
Select = P.Select()
|
Select = P.Select()
|
||||||
Greater = P.Greater()
|
Greater = P.Greater()
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
LambUpdateWithLrV2 = Primitive('LambUpdateWithLrV2')
|
LambUpdateWithLrV2 = Primitive('LambUpdateWithLrV2')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,13 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
lars_v2 = Primitive('LarsV2')
|
lars_v2 = Primitive('LarsV2')
|
||||||
square_sum_all = Primitive('SquareSumAll')
|
square_sum_all = Primitive('SquareSumAll')
|
||||||
lars_v2_update = Primitive('LarsV2Update')
|
lars_v2_update = Primitive('LarsV2Update')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Add = P.Add()
|
Add = P.Add()
|
||||||
Cast = P.Cast()
|
Cast = P.Cast()
|
||||||
LayerNormBetaGammaBackprop = Primitive('LayerNormBetaGammaBackprop')
|
LayerNormBetaGammaBackprop = Primitive('LayerNormBetaGammaBackprop')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
layer_norm_grad = G.LayerNormGrad()
|
layer_norm_grad = G.LayerNormGrad()
|
||||||
layer_norm_x_backprop = Primitive('LayerNormXBackprop')
|
layer_norm_x_backprop = Primitive('LayerNormXBackprop')
|
||||||
layer_norm_beta_gamma_backprop = Primitive('LayerNormBetaGammaBackprop')
|
layer_norm_beta_gamma_backprop = Primitive('LayerNormBetaGammaBackprop')
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
depend = P.Depend()
|
depend = P.Depend()
|
||||||
addn = P.AddN()
|
addn = P.AddN()
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
|
|
|
@ -16,11 +16,12 @@ import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
ApplyMomentum = P.ApplyMomentum()
|
ApplyMomentum = P.ApplyMomentum()
|
||||||
FusedMulApplyMomentum = Primitive('FusedMulApplyMomentum')
|
FusedMulApplyMomentum = Primitive('FusedMulApplyMomentum')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
constant = Tensor(1.0, mstype.float32)
|
constant = Tensor(1.0, mstype.float32)
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,13 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
fused_mul_add = Primitive('FusedMulAdd')
|
fused_mul_add = Primitive('FusedMulAdd')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -16,12 +16,13 @@ import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
addn = P.AddN()
|
addn = P.AddN()
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
fused_mul_addn = Primitive('FusedMulAddN')
|
fused_mul_addn = Primitive('FusedMulAddN')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
scalar = Tensor(1.0, mstype.float32)
|
scalar = Tensor(1.0, mstype.float32)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
reduce_min = P.ReduceMin(keep_dims=False)
|
reduce_min = P.ReduceMin(keep_dims=False)
|
||||||
reduce_min1 = Primitive('ReduceMin')
|
reduce_min1 = Primitive('ReduceMin')
|
||||||
reduce_min2 = Primitive('ReduceMin')
|
reduce_min2 = Primitive('ReduceMin')
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
BatchNorm = P.BatchNorm(is_training=True)
|
BatchNorm = P.BatchNorm(is_training=True)
|
||||||
BNTrainingReduce = Primitive('BNTrainingReduce')
|
BNTrainingReduce = Primitive('BNTrainingReduce')
|
||||||
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')
|
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')
|
||||||
|
|
|
@ -14,13 +14,14 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
Mul = P.Mul()
|
Mul = P.Mul()
|
||||||
ReduceSum = P.ReduceSum(keep_dims=True)
|
ReduceSum = P.ReduceSum(keep_dims=True)
|
||||||
Sub = P.Sub()
|
Sub = P.Sub()
|
||||||
SoftmaxGradExt = Primitive('SoftmaxGradExt')
|
SoftmaxGradExt = Primitive('SoftmaxGradExt')
|
||||||
MakeTuple = Primitive('make_tuple')
|
MakeTuple = Primitive('make_tuple')
|
||||||
TupleGetItem = Primitive('tuple_getitem')
|
TupleGetItem = Primitive(Constants.kTupleGetItem)
|
||||||
axes = (2, 3)
|
axes = (2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,11 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
split = P.Split(0, 8)
|
split = P.Split(0, 8)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
splitv = Primitive('SplitV')
|
splitv = Primitive('SplitV')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
square = P.Square()
|
square = P.Square()
|
||||||
reduce_sum = P.ReduceSum()
|
reduce_sum = P.ReduceSum()
|
||||||
square_sumv1 = Primitive('SquareSumV1')
|
square_sumv1 = Primitive('SquareSumV1')
|
||||||
|
|
|
@ -14,12 +14,13 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tensor_scatter_update = P.TensorScatterUpdate()
|
tensor_scatter_update = P.TensorScatterUpdate()
|
||||||
tensor_move = Primitive('TensorMove')
|
tensor_move = Primitive('TensorMove')
|
||||||
scatter_nd_update = Primitive('ScatterNdUpdate')
|
scatter_nd_update = Primitive('ScatterNdUpdate')
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
TopK = P.TopK()
|
TopK = P.TopK()
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
|
|
||||||
|
|
||||||
class FnDict:
|
class FnDict:
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
unsorted_segment_sum = P.UnsortedSegmentSum()
|
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||||
num_segments = 4
|
num_segments = 4
|
||||||
padding = Primitive('Padding')
|
padding = Primitive('Padding')
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
addn = P.AddN()
|
addn = P.AddN()
|
||||||
add = P.Add()
|
add = P.Add()
|
||||||
reshape = P.Reshape()
|
reshape = P.Reshape()
|
||||||
cast = P.Cast()
|
cast = P.Cast()
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||||
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" multi_relu_case """
|
""" multi_relu_case """
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
|
|
||||||
# Test user define ops
|
# Test user define ops
|
||||||
|
@ -21,7 +22,7 @@ def get_test_ops_fn():
|
||||||
return test_ops_f
|
return test_ops_f
|
||||||
|
|
||||||
|
|
||||||
scalar_mul = Primitive('scalar_mul')
|
scalar_mul = Primitive(Constants.kScalarMul)
|
||||||
|
|
||||||
|
|
||||||
def test_ops_f(x, y):
|
def test_ops_f(x, y):
|
||||||
|
|
|
@ -14,18 +14,19 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" vm_test """
|
""" vm_test """
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
|
|
||||||
scala_add = Primitive('scalar_add')
|
scala_add = Primitive(Constants.kScalarAdd)
|
||||||
scala_mul = Primitive('scalar_mul')
|
scala_mul = Primitive(Constants.kScalarMul)
|
||||||
scalar_gt = Primitive('scalar_gt')
|
scalar_gt = Primitive('scalar_gt')
|
||||||
|
|
||||||
|
|
||||||
def scalar_add(x, y):
|
def ScalarAdd(x, y):
|
||||||
"""Implement `scalar_add`."""
|
"""Implement `scalar_add`."""
|
||||||
return scala_add(x, y)
|
return scala_add(x, y)
|
||||||
|
|
||||||
|
|
||||||
def scalar_mul(x, y):
|
def ScalarMul(x, y):
|
||||||
"""Implement `scalar_mul`."""
|
"""Implement `scalar_mul`."""
|
||||||
return scala_mul(x, y)
|
return scala_mul(x, y)
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ir/manager.h"
|
#include "ir/manager.h"
|
||||||
#include "pipeline/jit/static_analysis/prim.h"
|
#include "pipeline/jit/static_analysis/prim.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace validator {
|
namespace validator {
|
||||||
|
@ -35,7 +36,7 @@ class TestValidator : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestValidator, ValidateOperation01) {
|
TEST_F(TestValidator, ValidateOperation01) {
|
||||||
auto node = NewValueNode(std::make_shared<Primitive>("scalar_add"));
|
auto node = NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd));
|
||||||
ValidateOperation(node);
|
ValidateOperation(node);
|
||||||
// normally, the above statement should not exit, so expected the following statement execute
|
// normally, the above statement should not exit, so expected the following statement execute
|
||||||
EXPECT_TRUE(true);
|
EXPECT_TRUE(true);
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "utils/convert_utils.h"
|
#include "utils/convert_utils.h"
|
||||||
#include "utils/convert_utils_py.h"
|
#include "utils/convert_utils_py.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace compile {
|
namespace compile {
|
||||||
|
@ -46,7 +47,7 @@ class TestCompileSegmentRunner : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
|
TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
|
||||||
FuncGraphPtr g = get_py_fun_("scalar_add");
|
FuncGraphPtr g = get_py_fun_(prim::kScalarAdd);
|
||||||
// g was managed by local variable manager in get_py_fun_ and that manager will be freed as no reference.
|
// g was managed by local variable manager in get_py_fun_ and that manager will be freed as no reference.
|
||||||
// so a new manager should be declared to make get_outputs() in segment_runner.cc happy.
|
// so a new manager should be declared to make get_outputs() in segment_runner.cc happy.
|
||||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);
|
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);
|
||||||
|
@ -62,7 +63,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
|
TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
|
||||||
FuncGraphPtr g = get_py_fun_("scalar_mul");
|
FuncGraphPtr g = get_py_fun_(prim::kScalarMul);
|
||||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);
|
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);
|
||||||
|
|
||||||
BackendPtr b = std::make_shared<Backend>("vm");
|
BackendPtr b = std::make_shared<Backend>("vm");
|
||||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants as Constants
|
||||||
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
|
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
|
||||||
cancel_new_parameter, set_reopt
|
cancel_new_parameter, set_reopt
|
||||||
from mindspore.common.api import _generate_pip_args
|
from mindspore.common.api import _generate_pip_args
|
||||||
|
@ -296,12 +297,12 @@ def test_imm_target():
|
||||||
pattern = Call(P.Softmax(), [x])
|
pattern = Call(P.Softmax(), [x])
|
||||||
imm = Imm(0)
|
imm = Imm(0)
|
||||||
target_0 = Call("make_tuple", [pattern])
|
target_0 = Call("make_tuple", [pattern])
|
||||||
target = Call("tuple_getitem", [target_0, imm])
|
target = Call(Constants.kTupleGetItem, [target_0, imm])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
unregiste_pass(softmax_pass)
|
unregiste_pass(softmax_pass)
|
||||||
assert "make_tuple" in transformed_repr
|
assert "make_tuple" in transformed_repr
|
||||||
assert "tuple_getitem" in transformed_repr
|
assert Constants.kTupleGetItem in transformed_repr
|
||||||
assert "Softmax" in transformed_repr
|
assert "Softmax" in transformed_repr
|
||||||
|
|
||||||
def test_gen_new_parameter():
|
def test_gen_new_parameter():
|
||||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common.api import ms_function
|
from mindspore.common.api import ms_function
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import _constants
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
@ -28,7 +29,7 @@ from ...ut_filter import non_graph_engine
|
||||||
|
|
||||||
|
|
||||||
tensor_add = P.Add()
|
tensor_add = P.Add()
|
||||||
scala_add = Primitive('scalar_add')
|
scala_add = Primitive(_constants.kScalarAdd)
|
||||||
add = C.MultitypeFuncGraph('add')
|
add = C.MultitypeFuncGraph('add')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,12 +21,13 @@ from mindspore.common.parameter import Parameter
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import _constants
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
from ...ut_filter import non_graph_engine
|
from ...ut_filter import non_graph_engine
|
||||||
|
|
||||||
tensor_add = P.Add()
|
tensor_add = P.Add()
|
||||||
op_add = P.AddN()
|
op_add = P.AddN()
|
||||||
scala_add = Primitive('scalar_add')
|
scala_add = Primitive(_constants.kScalarAdd)
|
||||||
add = C.MultitypeFuncGraph('add')
|
add = C.MultitypeFuncGraph('add')
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue