!11915 Change TensorAdd to Add, merge from r1.1 to master

From: @liangzhibo
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
Commit: e897eb4c41
Committed-by: mindspore-ci-bot, 2021-02-02 09:16:08 +08:00 (via Gitee)
259 changed files with 608 additions and 607 deletions
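
Note on the user-visible effect (not part of the diff itself): the change is purely a rename of the elementwise addition primitive, so graph code that previously emitted or instantiated 'TensorAdd' now uses 'Add'. A minimal illustrative Python sketch, assuming the mindspore.ops.operations module:

import mindspore.ops.operations as P

add = P.Add()          # primitive name after this commit
# add = P.TensorAdd()  # former name, replaced throughout the files below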


@@ -36,24 +36,24 @@ def expand_biasadd(expand_info):
                 'ExpandDims', [input_y], attrs={'axis': 1})
             input_y_expand = graph_builder.emit(
                 'ExpandDims', [input_y_expand], attrs={'axis': 2})
-            result = graph_builder.emit('TensorAdd', [input_x, input_y_expand])
+            result = graph_builder.emit('Add', [input_x, input_y_expand])
         elif input_x.data_format == "DefaultFormat":
             if len(input_x.shape) == 2:
-                result = graph_builder.emit('TensorAdd', [input_x, input_y])
+                result = graph_builder.emit('Add', [input_x, input_y])
             elif len(input_x.shape) == 3:
                 input_y_expand = graph_builder.emit(
                     'ExpandDims', [input_y], attrs={'axis': 1})
                 result = graph_builder.emit(
-                    'TensorAdd', [input_x, input_y_expand])
+                    'Add', [input_x, input_y_expand])
             else:
                 input_y_expand = graph_builder.emit(
                     'ExpandDims', [input_y], attrs={'axis': 1})
                 input_y_expand = graph_builder.emit(
                     'ExpandDims', [input_y_expand], attrs={'axis': 2})
                 result = graph_builder.emit(
-                    'TensorAdd', [input_x, input_y_expand])
+                    'Add', [input_x, input_y_expand])
         else:
-            result = graph_builder.emit('TensorAdd', [input_x, input_y])
+            result = graph_builder.emit('Add', [input_x, input_y])
         # set graph output.
         graph_scope.set_output(result)
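
For readers of the expander above: the branches only reshape the 1-D bias so that it broadcasts against the input before the renamed Add. A rough NumPy equivalent of the DefaultFormat 4-D branch (function name and shapes are illustrative assumptions, not MindSpore API):

import numpy as np

def bias_add_default_format(x, bias):
    # x: (N, C, H, W); bias: (C,) expanded twice to (C, 1, 1), mirroring the
    # two ExpandDims emitted before the Add in the hunk above.
    return x + bias.reshape(bias.shape[0], 1, 1)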


@@ -49,13 +49,13 @@ def expand_fusedadam(expand_info):
         # compute result
         beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
         one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
-        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+        next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
         beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
         grad_square = graph_builder.emit('Mul', [gradient, gradient])
         one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
-        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+        next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
         sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
-        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+        sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
         update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
         update_with_lr = graph_builder.emit('Mul', [lr, update])
         next_para = graph_builder.emit('Sub', [param, update_with_lr])
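
The expansion above is a graph-level spelling of the usual Adam step. An illustrative NumPy restatement (variable names mirror the expander's intermediates; the sketch itself is not part of the change):

import numpy as np

def fused_adam_step(param, m, v, gradient, beta_1, beta_2,
                    one_sub_beta_1, one_sub_beta_2, lr, eps):
    next_m = beta_1 * m + one_sub_beta_1 * gradient               # first moment
    next_v = beta_2 * v + one_sub_beta_2 * (gradient * gradient)  # second moment
    update = next_m / (np.sqrt(next_v) + eps)
    next_para = param - lr * update
    return next_para, next_m, next_v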


@@ -52,16 +52,16 @@ def expand_fusedadamweightdecay(expand_info):
         # compute result
         beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
         one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
-        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+        next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
         beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
         grad_square = graph_builder.emit('Mul', [gradient, gradient])
         one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
-        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+        next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
         sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
-        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+        sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
         update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
         param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
-        update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay])
+        update = graph_builder.emit('Add', [update, param_with_weight_decay])
         update_with_lr = graph_builder.emit('Mul', [lr, update])
         next_para = graph_builder.emit('Sub', [param, update_with_lr])
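
The weight-decay variant differs from the plain FusedAdam expansion only in the decoupled decay term folded into the update before the learning-rate scaling; schematically (again just an illustrative sketch):

import numpy as np

def fused_adam_weight_decay_step(param, m, v, gradient, beta_1, beta_2,
                                 one_sub_beta_1, one_sub_beta_2, lr, weight_decay, eps):
    next_m = beta_1 * m + one_sub_beta_1 * gradient
    next_v = beta_2 * v + one_sub_beta_2 * (gradient * gradient)
    update = next_m / (np.sqrt(next_v) + eps)
    update = update + weight_decay * param   # decoupled weight decay, as in the hunk above
    return param - lr * update, next_m, next_v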


@@ -42,7 +42,7 @@ def expand_gelu(expand_info):
         pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
         const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
         mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
-        tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
+        tanh_res = graph_builder.emit('Add', [input_x, mul_1])
         const_csvalue_sqrt_two_div_pi = graph_builder.value(
             tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
         y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
@@ -51,7 +51,7 @@ def expand_gelu(expand_info):
         tanh_y = graph_builder.emit('Tanh', [y])
         const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
         const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
-        tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
+        tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
         mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
         result = graph_builder.emit('Mul', [const_half, mul_x])
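
The graph built above computes the tanh approximation of GELU. A compact NumPy sketch, assuming the usual values behind CSVALUE and CSVALUE_SQRT_TWO_DIV_PI (0.044715 and sqrt(2/pi)); the function name is illustrative:

import numpy as np

def gelu_tanh_approx(x, csvalue=0.044715, sqrt_two_div_pi=0.7978845608028564):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), matching the emitted graph
    inner = sqrt_two_div_pi * (x + csvalue * x * x * x)
    return 0.5 * x * (1.0 + np.tanh(inner))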


@@ -55,18 +55,18 @@ def expand_gelugrad(expand_info):
         # cal mul_right
         mul_double = graph_builder.emit('Mul', [input_x, input_x])
         mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
-        mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
+        mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
         mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
         # cal tanh_para
         mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
         mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
-        mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
+        mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
         tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
         # cal 0.5 * (1.0 + tanh(tahn_para))
         tanh_res = graph_builder.emit('Tanh', [tanh_para])
-        tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
+        tanh_res_add_one = graph_builder.emit('Add', [const_one, tanh_res])
         half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
         # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
@@ -77,7 +77,7 @@ def expand_gelugrad(expand_info):
         mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
         # cal result
-        result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
+        result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
         result = graph_builder.emit('Mul', [input_dy, result_tmp])
         # set graph output.


@@ -68,13 +68,13 @@ def expand_layernorm(expand_info):
         # Calculate normalize
         normalize_sub = graph_builder.emit('Sub', [input_x, mean])
         epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
-        normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
+        normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
         normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
         normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
         # Calculate scale and translate
         scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
-        res = graph_builder.emit('TensorAdd', [scale_mul, input_beta])
+        res = graph_builder.emit('Add', [scale_mul, input_beta])
         # set graph output.
         graph_scope.set_output(res, mean, variance)
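
The normalization path above applies the standard LayerNorm transform to the precomputed mean and variance. A NumPy restatement of just this part (illustrative; names follow the expander):

import numpy as np

def layernorm_normalize(x, mean, variance, gamma, beta, epsilon):
    # res = gamma * (x - mean) * rsqrt(variance + epsilon) + beta, as emitted above
    normalized = (x - mean) / np.sqrt(variance + epsilon)
    return gamma * normalized + beta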


@@ -66,7 +66,7 @@ def expand_layernormgrad(expand_info):
         mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
         # cal dg db
-        var_eps = graph_builder.emit('TensorAdd', [variance, eps])
+        var_eps = graph_builder.emit('Add', [variance, eps])
         sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
         rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
         x_sub_mean = graph_builder.emit('Sub', [x, mean])
@@ -100,10 +100,10 @@ def expand_layernormgrad(expand_info):
         neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
         sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
         mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
-        add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
+        add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
         dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
-        dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
-        dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
+        dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
+        dx = graph_builder.emit('Add', [dx_tmp, dx_3])
         # set graph output.
         graph_scope.set_output(dx, dg, db)


@@ -131,7 +131,7 @@ class PrimLib:
     ]
     primtives = {
-        'TensorAdd': Prim(ELEMWISE),
+        'Add': Prim(ELEMWISE),
         'Abs': Prim(ELEMWISE),
         'Neg': Prim(ELEMWISE),
         'Mul': Prim(ELEMWISE),


@@ -238,7 +238,7 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out,
 void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  if (kernel_name == prim::kPrimTensorAdd->name()) {
+  if (kernel_name == prim::kPrimAdd->name()) {
     operate_type_ = ADD;
   } else if (kernel_name == prim::kPrimSub->name()) {
     operate_type_ = SUB;


@@ -37,8 +37,7 @@ class TensorAddCPUKernel : public MKLCPUKernel {
 };
 MS_REG_CPU_KERNEL(
-  TensorAdd,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   TensorAddCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore


@@ -51,8 +51,7 @@ MS_REG_GPU_KERNEL_ONE(
   Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   BroadcastOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
-  TensorAdd,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   BroadcastOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   FloorDiv,
@@ -103,8 +102,7 @@ MS_REG_GPU_KERNEL_ONE(
   Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   BroadcastOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
-  TensorAdd,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  Add, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   BroadcastOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
   FloorDiv,
@@ -133,7 +131,7 @@ MS_REG_GPU_KERNEL_ONE(
   Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   BroadcastOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
-  TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@@ -171,7 +169,7 @@ MS_REG_GPU_KERNEL_ONE(
   Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   BroadcastOpGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
-  TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   BroadcastOpGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
   Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),


@@ -145,7 +145,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
   static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
     {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
     {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
-    {"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
+    {"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
     {"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN},
   };


@@ -1063,7 +1063,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
 std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
   static std::map<std::string, std::string> buffer_fussion_op_map = {
-    {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}};
+    {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}};
   string result = origin_type;
   auto iter = buffer_fussion_op_map.find(origin_type);
   if (iter != buffer_fussion_op_map.end()) {


@@ -99,7 +99,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
         AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
       auto eltwise_input = cnode->input(1);
       MS_EXCEPTION_IF_NULL(eltwise_input);
-      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
+      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
        MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
      }
    }


@@ -28,7 +28,7 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
   return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 }
@@ -41,7 +41,7 @@ const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
   return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 }
@@ -54,7 +54,7 @@ const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
   return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 }
@@ -67,7 +67,7 @@ const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
   return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 }
@@ -80,7 +80,7 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
   return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 }
@@ -94,7 +94,7 @@ const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
   VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -114,7 +114,7 @@ const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
   VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -134,7 +134,7 @@ const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
   VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -154,7 +154,7 @@ const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
   VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -174,7 +174,7 @@ const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
   VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});


@@ -38,8 +38,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
       mul_x_input_vars_.push_back(std::make_shared<Var>());
     }
     add2_y_ = std::make_shared<Var>();
-    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
-    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
     sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
   }
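
All AdamApplyOne patterns in this pass (including the Cond and Assign variants above) match the same arithmetic; the variants only permute the inputs of the commutative Add and Mul nodes and optionally feed the result into an Assign. A rough scalar restatement, treating the pattern variables as plain inputs (sqrt0 and add2_y_ are sub-patterns in the real pass; this is only a sketch):

def adam_apply_one(input0, input2, input3, input4, mul_x0, mul_x1, sqrt0, add2_y):
    add0 = mul_x0 * input2 + mul_x1 * input0   # add0 = mul0 + mul1
    true_div0 = add0 / (sqrt0 + add2_y)        # the Add renamed by this commit
    return input3 - input4 * true_div0         # parameter update matched by the fusion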


@@ -59,10 +59,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+  VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, input4_, add3});
   VectorRef sub0({prim::kPrimSub, input3_, mul5});
   return sub0;
@@ -79,10 +79,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+  VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
   VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({prim::kPrimSub, input3_, mul5});
   return sub0;
@@ -99,10 +99,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+  VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({prim::kPrimSub, input3_, mul5});
   return sub0;
@@ -119,10 +119,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+  VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({prim::kPrimSub, input3_, mul5});
   return sub0;
@@ -139,10 +139,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+  VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({prim::kPrimSub, input3_, mul5});
   return sub0;
@@ -159,10 +159,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+  VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, input4_, add3});
   VectorRef sub0({sub0_var_, input3_, mul5});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -184,10 +184,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+  VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
   VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({sub0_var_, input3_, mul5});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -209,10 +209,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+  VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({sub0_var_, input3_, mul5});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -234,10 +234,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+  VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({sub0_var_, input3_, mul5});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -259,10 +259,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const {
   VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
   VectorRef add1({add1_var_, mul2, mul3});
   VectorRef sqrt0({sqrt, add1});
-  VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+  VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
   VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
   VectorRef real_div0({real_div, add0, add2});
-  VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+  VectorRef add3({prim::kPrimAdd, mul4, real_div0});
   VectorRef mul5({prim::kPrimMul, add3, input4_});
   VectorRef sub0({sub0_var_, input3_, mul5});
   VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});


@@ -38,8 +38,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
     mul3_x_ = std::make_shared<Var>();
     mul4_x_ = std::make_shared<Var>();
     add2_y_ = std::make_shared<Var>();
-    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
-    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
     sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
   }
   ~AdamApplyOneWithDecayRule() override = default;


@@ -130,11 +130,11 @@ const BaseRef LambNextMVRuleCond1::DefinePattern() const {
   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
-  auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
+  auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
   auto sqrt0 = VectorRef({prim_rsqrt, add2});
   auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
-  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return VectorRef({prim::kPrimAdd, mul4, real_div2});
 }

 BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
@@ -147,7 +147,7 @@ BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, add2_y_, sqrt1});
   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
   return real_div4;
 }
@@ -166,11 +166,11 @@ const BaseRef LambNextMVRuleCond2::DefinePattern() const {
   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
-  auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
+  auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
   auto sqrt0 = VectorRef({prim_rsqrt, add2});
   auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
-  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return VectorRef({prim::kPrimAdd, mul4, real_div2});
 }

 BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
@@ -183,7 +183,7 @@ BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
   return real_div4;
 }
@@ -202,11 +202,11 @@ const BaseRef LambNextMVRuleCond3::DefinePattern() const {
   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
-  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+  auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
   auto sqrt0 = VectorRef({prim_rsqrt, add2});
   auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
-  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return VectorRef({prim::kPrimAdd, mul4, real_div2});
 }

 BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
@@ -219,7 +219,7 @@ BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
   return real_div4;
 }
@@ -238,11 +238,11 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const {
   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
-  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+  auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
   auto sqrt0 = VectorRef({prim_rsqrt, add2});
   auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
-  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+  return VectorRef({prim::kPrimAdd, real_div2, mul4});
 }

 BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
@@ -255,7 +255,7 @@ BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
   return real_div4;
 }


@@ -49,8 +49,8 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass {
     real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
     real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
     real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
-    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
-    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
   }
   ~LambNextMVRule() override = default;
   const BaseRef DefinePattern() const override = 0;


@@ -124,10 +124,10 @@ BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef mul4 = VectorRef({mul4_var_, Zs});
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+  VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
   VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
   return add3;
 }
@@ -141,14 +141,14 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
   VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
   VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
   VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
-  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
   return add5;
 }
@@ -165,10 +165,10 @@ BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef mul4 = VectorRef({mul4_var_, Zs});
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+  VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
   VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
   return add3;
 }
@@ -182,14 +182,14 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, constant_add2_y_, sqrt1});
   VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
   VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
   VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
-  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
   return add5;
 }
@@ -206,10 +206,10 @@ BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef mul4 = VectorRef({mul4_var_, Zs});
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
   VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
   return add3;
 }
@@ -223,14 +223,14 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
   VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
   VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
   VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
-  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
   return add5;
 }
@@ -248,10 +248,10 @@ BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
   VectorRef mul4 = VectorRef({mul4_var_, Zs});
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
   VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+  VectorRef add3 = VectorRef({prim::kPrimAdd, real_div2, mul4});
   return add3;
 }
@@ -265,14 +265,14 @@ const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
   VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
   VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); VectorRef add5 = VectorRef({prim::kPrimAdd, real_div4, mul4});
return add5; return add5;
} }
} // namespace opt } // namespace opt
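For reference, reading the shared pattern variables above at face value, each Cond variant appears to match the same LAMB update term written two ways, differing only in whether the constant add2_y is added before or after the square root (a sketch of the matched expression, not text from the patch):
\[
\text{add5} \;=\; \frac{\text{real\_div0}}{\sqrt{\text{real\_div1}} + \text{add2\_y}} + \text{mul4},
\qquad
\text{add3} \;=\; \text{real\_div0}\cdot\operatorname{rsqrt}(\text{real\_div1} + \text{add2\_y}) + \text{mul4},
\]
with \(\text{real\_div0} = (\text{mul0}+\text{mul1})/\text{input5}\) and \(\text{real\_div1} = (\text{mul2}+\text{mul3})/\text{input2}\).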

View File

@ -38,8 +38,8 @@ class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
} }
~LambNextMVWithDecayRule() override = default; ~LambNextMVWithDecayRule() override = default;

View File

@ -66,7 +66,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false; return false;
} }
auto add5 = node->cast<CNodePtr>(); auto add5 = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) { if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) {
return false; return false;
} }
auto real_div4_anf = add5->input(1); auto real_div4_anf = add5->input(1);
@ -82,7 +82,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false; return false;
} }
auto add4 = add4_anf->cast<CNodePtr>(); auto add4 = add4_anf->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) { if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) {
return false; return false;
} }
auto sqrt1_anf = add4->input(1); auto sqrt1_anf = add4->input(1);
@ -140,17 +140,17 @@ const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const {
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_}); VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_});
VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
VectorRef add1({prim::kPrimTensorAdd, mul2, mul3}); VectorRef add1({prim::kPrimAdd, mul2, mul3});
VectorRef real_div1({prim_real_div, add1, input2_}); VectorRef real_div1({prim_real_div, add1, input2_});
VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_}); VectorRef add2({prim::kPrimAdd, real_div1, add2_y_});
VectorRef mul0({prim::kPrimMul, mul0_x_, input4_}); VectorRef mul0({prim::kPrimMul, mul0_x_, input4_});
VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_}); VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_});
VectorRef sqrt0({prim_rsqrt, add2}); VectorRef sqrt0({prim_rsqrt, add2});
VectorRef add0({prim::kPrimTensorAdd, mul0, mul1}); VectorRef add0({prim::kPrimAdd, mul0, mul1});
VectorRef real_div0({prim_real_div, add0, input5_}); VectorRef real_div0({prim_real_div, add0, input5_});
VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0}); VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input6_}); VectorRef mul4({prim::kPrimMul, mul4_x_, input6_});
VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4}); VectorRef add3({prim::kPrimAdd, real_div2, mul4});
return add3; return add3;
} }

View File

@ -54,7 +54,7 @@ const BaseRef LambNextRightRule::DefinePattern() const {
VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})});
VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3});
return VectorRef( return VectorRef(
{prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); {prim::kPrimAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_});
} }
const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,

View File

@ -32,7 +32,7 @@ class LambNextRightRule : public PatternProcessPass {
mul3_x_(std::make_shared<Var>()), mul3_x_(std::make_shared<Var>()),
true_div1_recip_(std::make_shared<Var>()), true_div1_recip_(std::make_shared<Var>()),
add2_y_(std::make_shared<Var>()), add2_y_(std::make_shared<Var>()),
add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()))) {} add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()))) {}
~LambNextRightRule() override = default; ~LambNextRightRule() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;

View File

@ -58,7 +58,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
const BaseRef MulAddFusion::DefinePattern() const { const BaseRef MulAddFusion::DefinePattern() const {
VarPtr x = std::make_shared<Var>(); VarPtr x = std::make_shared<Var>();
VarPtr y = std::make_shared<Var>(); VarPtr y = std::make_shared<Var>();
VectorRef pattern({prim::kPrimTensorAdd, x, y}); VectorRef pattern({prim::kPrimAdd, x, y});
return pattern; return pattern;
} }

View File

@ -51,13 +51,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
} // namespace } // namespace
const BaseRef AdamFusion::DefinePattern() const { const BaseRef AdamFusion::DefinePattern() const {
VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef next_m = VectorRef(
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef next_v = VectorRef next_v =
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
VectorRef update = VectorRef( VectorRef update =
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
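For reference, the tree matched by AdamFusion above appears to encode the standard Adam moment and parameter update (symbols follow the pattern variables):
\[
\begin{aligned}
m' &= \beta_1 m + (1-\beta_1)\,g, \\
v' &= \beta_2 v + (1-\beta_2)\,g^2, \\
p' &= p - lr\cdot\frac{m'}{\epsilon + \sqrt{v'}},
\end{aligned}
\]
and the AdamWeightDecayFusion pattern below adds a weight_decay * param term to the update before it is scaled by lr.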

View File

@ -51,14 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
} // namespace } // namespace
const BaseRef AdamWeightDecayFusion::DefinePattern() const { const BaseRef AdamWeightDecayFusion::DefinePattern() const {
VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef next_m = VectorRef(
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef next_v = VectorRef next_v =
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
VectorRef update = VectorRef( VectorRef update =
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});

View File

@ -51,7 +51,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
} // namespace } // namespace
const BaseRef AddReluGradV2Fusion::DefinePattern() const { const BaseRef AddReluGradV2Fusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_}); VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimAdd, x1_, x2_}), mask_});
return relu_grad; return relu_grad;
} }

View File

@ -51,7 +51,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
} // namespace } // namespace
const BaseRef AddReluV2Fusion::DefinePattern() const { const BaseRef AddReluV2Fusion::DefinePattern() const {
VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})}); VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimAdd, x1_, x2_})});
return relu; return relu;
} }

View File

@ -30,7 +30,7 @@ namespace opt {
const BaseRef BatchNormAddReluFusion::DefinePattern() const { const BaseRef BatchNormAddReluFusion::DefinePattern() const {
VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
VectorRef tensor_add = VectorRef({prim::kPrimTensorAdd, tuple_get_item, z_}); VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
return relu; return relu;
} }

View File

@ -42,7 +42,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
MS_EXCEPTION_IF_NULL(B); MS_EXCEPTION_IF_NULL(B);
int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n"); int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n");
if (num_input == 2) { if (num_input == 2) {
auto prim = std::make_shared<Primitive>(prim::kPrimTensorAdd->name()); auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name());
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B}; std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B};
auto add_new = graph->NewCNode(inputs); auto add_new = graph->NewCNode(inputs);

View File

@ -47,7 +47,7 @@ AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_
} }
AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { AnfNodePtr SimplifyAdd(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimTensorAdd)) { if (!IsPrimitiveCNode(node, prim::kPrimAdd)) {
return nullptr; return nullptr;
} }
PatternNode<AnfNodePtr> x, y, z; PatternNode<AnfNodePtr> x, y, z;
@ -57,13 +57,13 @@ AnfNodePtr SimplifyAdd(const AnfNodePtr &node) {
PConstant<AnfNodePtr> any_const_2(node); PConstant<AnfNodePtr> any_const_2(node);
auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr { auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node);
return new_cnode; return new_cnode;
}; };
auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr { auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr {
auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node)); auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), new_rhs}, node);
return new_cnode; return new_cnode;
}; };
// A + 0 = A // A + 0 = A
@ -88,7 +88,7 @@ AnfNodePtr SimplifySub(const AnfNodePtr &node) {
PConstant<AnfNodePtr> any_const(node); PConstant<AnfNodePtr> any_const(node);
auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr { auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr {
auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg); auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), new_rhs}, node);
return new_cnode; return new_cnode;
}; };
// A - 0 = A // A - 0 = A
@ -269,7 +269,7 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
return new_cnode; return new_cnode;
}; };
auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr { auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
return new_cnode; return new_cnode;
}; };
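Reading the lambdas above literally (their matching expressions are outside this hunk), the simplifier appears to apply these algebraic rewrites in addition to the A + 0 = A and A - 0 = A cases noted in the comments:
\[
x\cdot c + y\cdot c \to (x+y)\cdot c,\qquad
(x+c_1)+c_2 \to x+(c_1+c_2),\qquad
x-c \to x+(-c),\qquad
e^{x}\cdot e^{y} \to e^{x+y}.
\]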

View File

@ -741,14 +741,14 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList() { std::vector<PrimitivePtr> GetFusibleOpList() {
#if ENABLE_D #if ENABLE_D
std::vector<PrimitivePtr> fusible_basic_ops = { std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimCast, prim::kPrimRealDiv}; prim::kPrimCast, prim::kPrimRealDiv};
#elif ENABLE_GPU #elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = { std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,

View File

@ -52,7 +52,7 @@ namespace opt {
namespace irpass { namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() { OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify", arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul}); MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});

View File

@ -272,7 +272,7 @@ class AddNEliminater : public AnfVisitor {
if (tuple_inputs.size() == 3) { if (tuple_inputs.size() == 3) {
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
tuple_inputs[2]}; tuple_inputs[2]};
mng->Replace(node, func_graph->NewCNode(new_xs)); mng->Replace(node, func_graph->NewCNode(new_xs));
@ -299,7 +299,7 @@ class AddNEliminater : public AnfVisitor {
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
auto new_addn = func_graph->NewCNode( auto new_addn = func_graph->NewCNode(
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
auto new_add = auto new_add =
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
(void)mng->Replace(node, new_add); (void)mng->Replace(node, new_add);
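The rewrite above relies on AddN over two tensors agreeing with the binary Add; a quick numerical check under the new op name (a sketch, assuming a PyNative-capable MindSpore build that contains this patch):
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.ones((2, 3), np.float32))
y = Tensor(np.full((2, 3), 2.0, np.float32))
# AddN takes one tuple-of-tensors input; Add takes two tensor inputs.
assert np.allclose(P.AddN()((x, y)).asnumpy(), P.Add()(x, y).asnumpy())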

View File

@ -860,7 +860,7 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
if (ops[iter_ops]->type() == L2_NORMALIZE) { if (ops[iter_ops]->type() == L2_NORMALIZE) {
return PrepareL2Normalize(ops, iter_ops, basic_stra); return PrepareL2Normalize(ops, iter_ops, basic_stra);
} }
if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL || if (ops[iter_ops]->type() == ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
ops[iter_ops]->type() == DIV) { ops[iter_ops]->type() == DIV) {
return CheckBroadcast(ops, iter_ops, basic_stra); return CheckBroadcast(ops, iter_ops, basic_stra);
} }

View File

@ -78,7 +78,7 @@ const std::map<std::string, OperatorType> DictOpType{
// Elm-wise OP // Elm-wise OP
{TRANSPOSE, OperatorType::kRecElmWiseOp}, {TRANSPOSE, OperatorType::kRecElmWiseOp},
{L2_NORMALIZE, OperatorType::kRecElmWiseOp}, {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
{TENSOR_ADD, OperatorType::kRecElmWiseOp}, {ADD, OperatorType::kRecElmWiseOp},
{TENSOR_DOT, OperatorType::kRecElmWiseOp}, {TENSOR_DOT, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp},
{MUL, OperatorType::kRecElmWiseOp}, {MUL, OperatorType::kRecElmWiseOp},

View File

@ -86,7 +86,7 @@ REGISTER(LogSoftmaxInfo);
REGISTER(ActivationInfo); REGISTER(ActivationInfo);
REGISTER(SoftmaxCrossEntropyWithLogitsInfo); REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
REGISTER(SubInfo); REGISTER(SubInfo);
REGISTER(TensorAddInfo); REGISTER(AddInfo);
REGISTER(BiasAddInfo); REGISTER(BiasAddInfo);
REGISTER(MulInfo); REGISTER(MulInfo);
REGISTER(DivInfo); REGISTER(DivInfo);

View File

@ -60,12 +60,11 @@ class SubInfo : public ArithmeticBase {
~SubInfo() override = default; ~SubInfo() override = default;
}; };
class TensorAddInfo : public ArithmeticBase { class AddInfo : public ArithmeticBase {
public: public:
TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, AddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
~TensorAddInfo() override = default; ~AddInfo() override = default;
}; };
class MulInfo : public ArithmeticBase { class MulInfo : public ArithmeticBase {

View File

@ -191,7 +191,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)}); auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)}); auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast}); auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)}); auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)});
auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add}); auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});
auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)}); auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)});
Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_); Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_);

View File

@ -200,7 +200,7 @@ constexpr char MAXPOOLV2[] = "MaxPoolV2";
constexpr char L2_NORMALIZE[] = "L2Normalize"; constexpr char L2_NORMALIZE[] = "L2Normalize";
constexpr char TRANSPOSE[] = "Transpose"; constexpr char TRANSPOSE[] = "Transpose";
constexpr char RESHAPE[] = "Reshape"; constexpr char RESHAPE[] = "Reshape";
constexpr char TENSOR_ADD[] = "TensorAdd"; constexpr char ADD[] = "Add";
constexpr char BIAS_ADD[] = "BiasAdd"; constexpr char BIAS_ADD[] = "BiasAdd";
constexpr char SUB[] = "Sub"; constexpr char SUB[] = "Sub";
constexpr char MUL[] = "Mul"; constexpr char MUL[] = "Mul";
@ -315,7 +315,6 @@ constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char ADD[] = "Add";
constexpr char DROPOUT[] = "Dropout"; constexpr char DROPOUT[] = "Dropout";
constexpr char KStridedSlice[] = "StridedSlice"; constexpr char KStridedSlice[] = "StridedSlice";
constexpr char UNIQUE[] = "Unique"; constexpr char UNIQUE[] = "Unique";

View File

@ -151,7 +151,7 @@ bool IsSplittableOperator(const std::string &op_name) {
// clang-format off // clang-format off
static const std::set<std::string> splittable_op = static const std::set<std::string> splittable_op =
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,

View File

@ -165,7 +165,7 @@ class OpNameInfo {
#define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ #define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \
OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); }
OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(Add, Add, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo())
@ -257,7 +257,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); fn(OP_CONVERT_FUNCTION_NAME(Add)());
fn(OP_CONVERT_FUNCTION_NAME(Mul)()); fn(OP_CONVERT_FUNCTION_NAME(Mul)());
fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); fn(OP_CONVERT_FUNCTION_NAME(ReLU)());

View File

@ -29,7 +29,7 @@ REG_ADPT_DESC(StateSetItem, prim::kPrimStateSetItem->name(), ADPT_DESC(Assign))
INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Add) = EMPTY_ATTR_MAP; ATTR_MAP(Add) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Add, prim::kPrimTensorAdd->name(), REG_ADPT_DESC(Add, prim::kPrimAdd->name(),
std::make_shared<OpAdapterDesc>( std::make_shared<OpAdapterDesc>(
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})), std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})),
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})))) std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}}))))

View File

@ -215,7 +215,7 @@ constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax"; constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax"; constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
constexpr auto kTensorAddOpName = "TensorAdd"; constexpr auto kTensorAddOpName = "Add";
constexpr auto kCastOpName = "Cast"; constexpr auto kCastOpName = "Cast";
constexpr auto kGreaterEqualOpName = "GreaterEqual"; constexpr auto kGreaterEqualOpName = "GreaterEqual";
constexpr auto kAbsOpName = "Abs"; constexpr auto kAbsOpName = "Abs";

View File

@ -46,7 +46,7 @@ class ExportToQuantInferNetwork:
Returns: Returns:
Cell, Infer network. Cell, Infer network.
""" """
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
network = Validator.check_isinstance('network', network, (nn.Cell,)) network = Validator.check_isinstance('network', network, (nn.Cell,))
@ -225,7 +225,7 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
Returns: Returns:
Cell, Infer network. Cell, Infer network.
""" """
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir) super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir)

View File

@ -173,7 +173,7 @@ class QuantizationAwareTraining(Quantizer):
>>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
>>> net_qat = quantizer.quantize(net) >>> net_qat = quantizer.quantize(net)
""" """
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
def __init__(self, def __init__(self,
bn_fold=True, bn_fold=True,

View File

@ -91,8 +91,8 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);

View File

@ -60,8 +60,8 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
return out->Broaden(); return out->Broaden();
} }
AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors. // Inputs: two tensors.
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2); CheckArgsSize(op_name, args_spec_list, 2);

View File

@ -37,7 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMul, {InferImplMul, true}}, {prim::kPrimMul, {InferImplMul, true}},
{prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, {prim::kPrimAdd, {InferImplAdd, true}},
{prim::kPrimSquare, {InferImplSquare, true}}, {prim::kPrimSquare, {InferImplSquare, true}},
{prim::kPrimSqrt, {InferImplSqrt, true}}, {prim::kPrimSqrt, {InferImplSqrt, true}},
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},

View File

@ -236,7 +236,7 @@ inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primiti
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
// Maths // Maths
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add");
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");

View File

@ -49,6 +49,6 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape()); InferShape(primitive, input_args)->shape());
} }
REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimTensorAdd, AddInfer); REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
REGISTER_PRIMITIVE_C(kNameAdd, Add); REGISTER_PRIMITIVE_C(kNameAdd, Add);
} // namespace mindspore } // namespace mindspore

View File

@ -989,7 +989,7 @@ class PConstant : public PBase<PConstant<T> > {
} }
// Arithmetic operations // Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false); BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false); BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);

View File

@ -225,7 +225,7 @@ class LambNextMV(GraphKernel):
def __init__(self): def __init__(self):
super(LambNextMV, self).__init__() super(LambNextMV, self).__init__()
self.mul = P.Mul() self.mul = P.Mul()
self.add = P.TensorAdd() self.add = P.Add()
self.div = P.RealDiv() self.div = P.RealDiv()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.rsqrt = P.Rsqrt() self.rsqrt = P.Rsqrt()

View File

@ -651,7 +651,7 @@ class LogSigmoid(Cell):
super(LogSigmoid, self).__init__() super(LogSigmoid, self).__init__()
self.mul = P.Mul() self.mul = P.Mul()
self.exp = P.Exp() self.exp = P.Exp()
self.add = P.TensorAdd() self.add = P.Add()
self.rec = P.Reciprocal() self.rec = P.Reciprocal()
self.log = P.Log() self.log = P.Log()

View File

@ -441,13 +441,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
self.mul = P.Mul() self.mul = P.Mul()
self.inf_mask_mul = P.Mul() self.inf_mask_mul = P.Mul()
self.bias_add = P.TensorAdd() self.bias_add = P.Add()
self.inf_add = P.TensorAdd() self.inf_add = P.Add()
self.merge_op = None self.merge_op = None
self.count_op = P.UnsortedSegmentSum() self.count_op = P.UnsortedSegmentSum()
self.abs = P.Abs() self.abs = P.Abs()
self.equal = P.Equal() self.equal = P.Equal()
self.add = P.TensorAdd() self.add = P.Add()
self.cast = P.Cast() self.cast = P.Cast()
self.div_no_nan = P.DivNoNan() self.div_no_nan = P.DivNoNan()
self.expand = P.ExpandDims() self.expand = P.ExpandDims()

View File

@ -99,8 +99,8 @@ class BatchNormFoldCell(Cell):
else: else:
batch_mean = P.ZerosLike()(variance) batch_mean = P.ZerosLike()(variance)
batch_std = P.OnesLike()(variance) batch_std = P.OnesLike()(variance)
running_mean = P.TensorAdd()(mean, 0.) running_mean = P.Add()(mean, 0.)
running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon)) running_std = P.Sqrt()(P.Add()(variance, self.epsilon))
return batch_mean, batch_std, running_mean, running_std return batch_mean, batch_std, running_mean, running_std
@ -559,7 +559,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
return s return s
def construct(self, x): def construct(self, x):
running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps)) running_std = P.Sqrt()(P.Add()(self.moving_variance, self.eps))
scale_factor = self.gamma / running_std scale_factor = self.gamma / running_std
if self.channel_axis: if self.channel_axis:
scale_factor = self.reshape(scale_factor, (1, -1, 1, 1)) scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
@ -1236,7 +1236,7 @@ class TensorAddQuant(Cell):
ema=True, ema=True,
ema_decay=ema_decay, ema_decay=ema_decay,
quant_dtype=quant_dtype) quant_dtype=quant_dtype)
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x1, x2): def construct(self, x1, x2):
x = self.add(x1, x2) x = self.add(x1, x2)

View File

@ -155,9 +155,9 @@ def bprop_batchmatmul(self):
return bprop return bprop
@bprop_getters.register(P.TensorAdd) @bprop_getters.register(P.Add)
def get_bprop_tensor_add(self): def get_bprop_tensor_add(self):
"""Grad definition for `TensorAdd` operation.""" """Grad definition for `Add` operation."""
def bprop(x, y, out, dout): def bprop(x, y, out, dout):
return binop_grad_common(x, y, dout, dout) return binop_grad_common(x, y, dout, dout)
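Since the derivative of an element-wise add is one with respect to each input, the bprop above simply forwards dout to both inputs:
\[
\frac{\partial(x+y)}{\partial x}=\frac{\partial(x+y)}{\partial y}=1 \;\Rightarrow\; dx = dy = dout,
\]
with binop_grad_common (our reading of the helper, which is not shown in this hunk) reduce-summing each gradient over any broadcast axes so it keeps the shape of its input.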

View File

@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""TensorAdd op""" """Add op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("TensorAdd") \ op_info = AkgAscendRegOp("Add") \
.fusion_type("ELEMWISE") \ .fusion_type("ELEMWISE") \
.input(0, "x") \ .input(0, "x") \
.input(1, "y") \ .input(1, "y") \
@ -38,5 +38,5 @@ op_info = AkgAscendRegOp("TensorAdd") \
@op_info_register(op_info) @op_info_register(op_info)
def _add_akg(): def _add_akg():
"""TensorAdd Akg register""" """Add Akg register"""
return return

View File

@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""TensorAdd op""" """Add op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tensor_add_op_info = TBERegOp("TensorAdd") \ tensor_add_op_info = TBERegOp("Add") \
.fusion_type("ELEMWISE") \ .fusion_type("ELEMWISE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("add.so") \ .binfile_name("add.so") \

View File

@ -16,7 +16,7 @@
"""TensorAdd op""" """TensorAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tensor_add_op_info = TBERegOp("TensorAdd") \ tensor_add_op_info = TBERegOp("Add") \
.fusion_type("ELEMWISE") \ .fusion_type("ELEMWISE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("add.so") \ .binfile_name("add.so") \

View File

@ -395,7 +395,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
>>> from mindspore.ops import Primitive, operations as P >>> from mindspore.ops import Primitive, operations as P
>>> from mindspore import dtype as mstype >>> from mindspore import dtype as mstype
>>> >>>
>>> tensor_add = P.TensorAdd() >>> tensor_add = P.Add()
>>> add = MultitypeFuncGraph('add') >>> add = MultitypeFuncGraph('add')
>>> @add.register("Number", "Number") >>> @add.register("Number", "Number")
... def add_scala(x, y): ... def add_scala(x, y):

View File

@ -51,7 +51,7 @@ merge = P.Merge()
geswitch = P.GeSwitch() geswitch = P.GeSwitch()
addn = P.AddN() addn = P.AddN()
absolute = P.Abs() absolute = P.Abs()
tensor_add = P.TensorAdd() tensor_add = P.Add()
neg_tensor = P.Neg() neg_tensor = P.Neg()
tensor_lt = P.Less() tensor_lt = P.Less()
tensor_le = P.LessEqual() tensor_le = P.LessEqual()

View File

@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse) MatrixInverse)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
@ -102,6 +102,7 @@ __all__ = [
'Sort', 'Sort',
'EditDistance', 'EditDistance',
'CropAndResize', 'CropAndResize',
'Add',
'TensorAdd', 'TensorAdd',
'Argmax', 'Argmax',
'Argmin', 'Argmin',

View File

@ -106,7 +106,7 @@ class GeSwitch(PrimitiveWithInfer):
... def __init__(self): ... def __init__(self):
... super(Net, self).__init__() ... super(Net, self).__init__()
... self.square = ops.Square() ... self.square = ops.Square()
... self.add = ops.TensorAdd() ... self.add = ops.Add()
... self.value = Tensor(np.full((1), 3), mindspore.float32) ... self.value = Tensor(np.full((1), 3), mindspore.float32)
... self.switch = ops.GeSwitch() ... self.switch = ops.GeSwitch()
... self.merge = ops.Merge() ... self.merge = ops.Merge()

View File

@ -66,7 +66,7 @@ class ScalarSummary(PrimitiveWithInfer):
... def __init__(self,): ... def __init__(self,):
... super(SummaryDemo, self).__init__() ... super(SummaryDemo, self).__init__()
... self.summary = ops.ScalarSummary() ... self.summary = ops.ScalarSummary()
... self.add = ops.TensorAdd() ... self.add = ops.Add()
... ...
... def construct(self, x, y): ... def construct(self, x, y):
... name = "x" ... name = "x"
@ -149,7 +149,7 @@ class TensorSummary(PrimitiveWithInfer):
... def __init__(self,): ... def __init__(self,):
... super(SummaryDemo, self).__init__() ... super(SummaryDemo, self).__init__()
... self.summary = ops.TensorSummary() ... self.summary = ops.TensorSummary()
... self.add = ops.TensorAdd() ... self.add = ops.Add()
... ...
... def construct(self, x, y): ... def construct(self, x, y):
... x = self.add(x, y) ... x = self.add(x, y)
@ -191,7 +191,7 @@ class HistogramSummary(PrimitiveWithInfer):
... def __init__(self,): ... def __init__(self,):
... super(SummaryDemo, self).__init__() ... super(SummaryDemo, self).__init__()
... self.summary = ops.HistogramSummary() ... self.summary = ops.HistogramSummary()
... self.add = ops.TensorAdd() ... self.add = ops.Add()
... ...
... def construct(self, x, y): ... def construct(self, x, y):
... x = self.add(x, y) ... x = self.add(x, y)
@ -409,7 +409,7 @@ class Assert(PrimitiveWithInfer):
... def __init__(self): ... def __init__(self):
... super(AssertDemo, self).__init__() ... super(AssertDemo, self).__init__()
... self.assert1 = ops.Assert(summarize=10) ... self.assert1 = ops.Assert(summarize=10)
... self.add = ops.TensorAdd() ... self.add = ops.Add()
... ...
... def construct(self, x, y): ... def construct(self, x, y):
... data = self.add(x, y) ... data = self.add(x, y)

View File

@ -18,6 +18,7 @@
import copy import copy
import numpy as np import numpy as np
from mindspore import log as logger
from ... import context from ... import context
from .. import signature as sig from .. import signature as sig
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
@ -114,7 +115,7 @@ class _BitwiseBinaryOp(_MathBinaryOp):
return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name) return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name)
class TensorAdd(_MathBinaryOp): class Add(_MathBinaryOp):
r""" r"""
Adds two input tensors element-wise. Adds two input tensors element-wise.
@ -143,7 +144,7 @@ class TensorAdd(_MathBinaryOp):
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> add = ops.TensorAdd() >>> add = ops.Add()
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> input_y = Tensor(np.array([4, 5, 6]).astype(np.float32)) >>> input_y = Tensor(np.array([4, 5, 6]).astype(np.float32))
>>> output = add(input_x, input_y) >>> output = add(input_x, input_y)
@ -160,6 +161,10 @@ class TensorAdd(_MathBinaryOp):
return Tensor(out) return Tensor(out)
return None return None
def TensorAdd():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.")
return Add()
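A minimal usage sketch of the renamed operator and the compatibility shim above (assumes a build that already contains this change):
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([1, 2, 3], np.float32))
y = Tensor(np.array([4, 5, 6], np.float32))

add = P.Add()            # new name
print(add(x, y))         # [5. 7. 9.]

legacy = P.TensorAdd()   # old name: logs the WARN_DEPRECATED message and returns an Add instance
print(legacy(x, y))      # same result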
class AssignAdd(PrimitiveWithInfer): class AssignAdd(PrimitiveWithInfer):
""" """

View File

@ -16,7 +16,7 @@
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd from mindspore.ops.operations import Add
from src.var_init import KaimingNormal from src.var_init import KaimingNormal
@ -91,7 +91,7 @@ class InvertedResidual(nn.Cell):
]) ])
self.conv = nn.SequentialCell(layers) self.conv = nn.SequentialCell(layers)
self.add = TensorAdd() self.add = Add()
self.cast = P.Cast() self.cast = P.Cast()
def construct(self, x): def construct(self, x):

View File

@ -198,7 +198,7 @@ class BasicBlock(nn.Cell):
self.bn2 = ms_fused_bn(planes) self.bn2 = ms_fused_bn(planes)
self.relu = P.ReLU() self.relu = P.ReLU()
self.downsample = downsample self.downsample = downsample
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
residual = x residual = x

View File

@ -102,7 +102,7 @@ class Bottleneck(nn.Cell):
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.downsample = downsample self.downsample = downsample
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@ -222,7 +222,7 @@ class ResidualBlockUsing(nn.Cell):
self.bn_down_sample = self.bn_down_sample.set_train() self.bn_down_sample = self.bn_down_sample.set_train()
if not weights_update: if not weights_update:
self.conv_down_sample.weight.requires_grad = False self.conv_down_sample.weight.requires_grad = False
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@ -218,7 +218,7 @@ class ResidualBlockUsing(nn.Cell):
self.bn_down_sample = self.bn_down_sample.set_train() self.bn_down_sample = self.bn_down_sample.set_train()
if not weights_update: if not weights_update:
self.conv_down_sample.weight.requires_grad = False self.conv_down_sample.weight.requires_grad = False
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@ -16,7 +16,7 @@
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd from mindspore.ops.operations import Add
from mindspore import Tensor from mindspore import Tensor
__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2'] __all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']
@ -129,7 +129,7 @@ class InvertedResidual(nn.Cell):
nn.BatchNorm2d(oup), nn.BatchNorm2d(oup),
]) ])
self.conv = nn.SequentialCell(layers) self.conv = nn.SequentialCell(layers)
self.add = TensorAdd() self.add = Add()
self.cast = P.Cast() self.cast = P.Cast()
def construct(self, x): def construct(self, x):

View File

@ -120,7 +120,7 @@ class InvertedResidual(nn.Cell):
nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
]) ])
self.conv = nn.SequentialCell(layers) self.conv = nn.SequentialCell(layers)
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
out = self.conv(x) out = self.conv(x)

View File

@ -123,7 +123,7 @@ class InvertedResidual(nn.Cell):
nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
]) ])
self.conv = nn.SequentialCell(layers) self.conv = nn.SequentialCell(layers)
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
out = self.conv(x) out = self.conv(x)

View File

@ -197,7 +197,7 @@ class ResUnit(nn.Cell):
padding=0, act_type=act_type, use_act=False) padding=0, act_type=act_type, use_act=False)
if num_in != num_out or stride != 1: if num_in != num_out or stride != 1:
self.use_short_cut_conv = False self.use_short_cut_conv = False
self.add = P.TensorAdd() if self.use_short_cut_conv else None self.add = P.Add() if self.use_short_cut_conv else None
def construct(self, x): def construct(self, x):
"""construct""" """construct"""

View File

@ -49,7 +49,7 @@ class DiceLoss(_Loss):
self.logical_or = P.LogicalOr() self.logical_or = P.LogicalOr()
self.equal = P.Equal() self.equal = P.Equal()
self.zeros_like = P.ZerosLike() self.zeros_like = P.ZerosLike()
self.add = P.TensorAdd() self.add = P.Add()
self.gather = P.Gather() self.gather = P.Gather()
def ohem_batch(self, scores, gt_texts, training_masks): def ohem_batch(self, scores, gt_texts, training_masks):

View File

@ -61,7 +61,7 @@ class ResidualBlock(nn.Cell):
kernel_size=1, stride=stride) kernel_size=1, stride=stride)
self.bn_down_sample = _bn(out_channels, momentum=momentum) self.bn_down_sample = _bn(out_channels, momentum=momentum)
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@ -152,7 +152,7 @@ class ResidualBlock(nn.Cell):
else: else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride, self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)]) use_se=self.use_se), _bn(out_channel)])
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@ -119,7 +119,7 @@ class ResidualBlock(nn.Cell):
if self.down_sample: if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@ -85,7 +85,7 @@ class ResidualBlock(nn.Cell):
self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel, self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
kernel_size=1, stride=stride, kernel_size=1, stride=stride,
pad_mode='same', padding=0, has_bn=True, activation='relu') pad_mode='same', padding=0, has_bn=True, activation='relu')
self.add = P.TensorAdd() self.add = P.Add()
self.relu = P.ReLU() self.relu = P.ReLU()
def construct(self, x): def construct(self, x):

View File

@ -215,7 +215,7 @@ class ResidualBlock(nn.Cell):
frequency=frequency, frequency=frequency,
batch_size=batch_size), batch_size=batch_size),
_bn(out_channel)]) _bn(out_channel)])
self.add = P.TensorAdd() self.add = P.Add()
def construct(self, x): def construct(self, x):
identity = x identity = x

View File

@@ -333,7 +333,7 @@ class Dense_Thor_GPU(Cell):
        self.gather = P.Gather()
        self.freq = Tensor(frequency, mstype.int32)
        self.axis = 0
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.sqrt = P.Sqrt()
        self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
@@ -690,7 +690,7 @@ class Dense_Thor(Cell):
        self.exp = P.Exp()
        self.dampingA = Tensor(np.identity(2048), mstype.float32)
        self.dampingG = Tensor(np.identity(1024), mstype.float32)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.sqrt = P.Sqrt()
        self.getG = P.InsertGradientOf(self.save_gradient)


@@ -16,7 +16,7 @@
ResNet based ResNext
"""
import mindspore.nn as nn
-from mindspore.ops.operations import TensorAdd, Split, Concat
+from mindspore.ops.operations import Add, Split, Concat
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
@@ -105,7 +105,7 @@ class BasicBlock(nn.Cell):
        self.down_sample = down_sample
        self.down_sample_flag = True
-        self.add = TensorAdd()
+        self.add = Add()

    def construct(self, x):
        identity = x
@@ -176,7 +176,7 @@ class Bottleneck(nn.Cell):
            self.down_sample_flag = True
        self.cast = P.Cast()
-        self.add = TensorAdd()
+        self.add = Add()

    def construct(self, x):
        identity = x
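
The ResNeXt backbone above imports the primitive class directly instead of going through the `P.` namespace, so its import line is renamed as well. Both spellings refer to the same element-wise add primitive after this change; a short sketch of the two equivalent forms (assuming a MindSpore build that already contains the rename):

    # Direct import of the primitive class, as in the ResNeXt file above.
    from mindspore.ops.operations import Add
    add_direct = Add()

    # The namespaced form used by most of the other files in this diff.
    from mindspore.ops import operations as P
    add_namespaced = P.Add()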


@@ -95,7 +95,7 @@ class ResidualBlock(nn.Cell):
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
                                                        _bn(out_channel)])
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, x):
        identity = x


@@ -68,7 +68,7 @@ class ShuffleV1Block(nn.Cell):
            outputs = oup
        self.relu = nn.ReLU()
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.concat = P.Concat(1)
        self.shape = P.Shape()
        self.transpose = P.Transpose()


@@ -170,7 +170,7 @@ class SqueezeNet_Residual(nn.Cell):
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.dropout = nn.Dropout(keep_prob=0.5)
        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()


@@ -133,7 +133,7 @@ class InvertedResidual(nn.Cell):
                _bn(oup),
            ])
        self.conv = nn.SequentialCell(layers)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.cast = P.Cast()
        self.last_relu = last_relu
        self.relu = nn.ReLU6()


@@ -68,7 +68,7 @@ class Block(nn.Cell):
        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
        self.rep = nn.SequentialCell(*rep)
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, inp):
        x = self.rep(inp)


@@ -62,7 +62,7 @@ class ResidualBlock(nn.Cell):
        out_chls = out_channels//2
        self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
        self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, x):
        identity = x


@@ -59,7 +59,7 @@ class ResidualBlock(nn.Cell):
        out_chls = out_channels//2
        self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
        self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, x):
        identity = x


@@ -107,7 +107,7 @@ class BasicBlock(nn.Cell):
        self.downsample = (in_channels != out_channels)
        if self.downsample:
            self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, x):
        identity = x


@@ -76,7 +76,7 @@ class ResidualBlock(nn.Cell):
        out_chls = out_channels
        self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
        self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, x):
        identity = x
@@ -111,7 +111,7 @@ class CspDarkNet53(nn.Cell):
        self.outchannel = 1024
        self.detect = detect
        self.concat = P.Concat(axis=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()

        self.conv0 = conv_block(3, 32, kernel_size=3, stride=1)
        self.conv1 = conv_block(32, 64, kernel_size=3, stride=2)


@@ -188,7 +188,7 @@ class EmbeddingPostprocessor(nn.Cell):
                                         use_one_hot=False)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, token_type_ids, word_embeddings):
        """Postprocessors apply positional and token type embeddings to word embeddings."""
@@ -226,7 +226,7 @@ class BertOutput(nn.Cell):
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()
@@ -444,7 +444,7 @@ class BertAttention(nn.Cell):
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
-            self.add = P.TensorAdd()
+            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
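
In BertAttention the renamed add participates in the usual additive attention mask: masked positions are converted into a large negative bias that is added to the raw attention scores before the softmax. A rough, self-contained sketch of that pattern (the shapes, the 1/0 mask convention, and the -10000 constant are assumptions for illustration, not values read from this diff):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.ops import operations as P

    sub = P.Sub()
    mul = P.Mul()
    add = P.Add()  # formerly P.TensorAdd()

    scores = Tensor(np.random.randn(1, 1, 4, 4).astype(np.float32))    # raw attention scores
    mask = Tensor(np.array([[[[1., 1., 0., 0.]]]], dtype=np.float32))  # 1 = attend, 0 = ignore

    # (1 - mask) * -10000 puts a large negative bias on the ignored positions.
    adder = mul(sub(Tensor(1.0, ms.float32), mask), Tensor(-10000.0, ms.float32))
    masked_scores = add(adder, scores)  # broadcasts the bias over the score matrix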


@@ -227,7 +227,7 @@ class EmbeddingPostprocessor(nn.Cell):
                                                  frequency=frequency)
        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
        self.layernorm = nn.LayerNorm((embedding_size,))
-        self.add = P.TensorAdd()
+        self.add = P.Add()

    def construct(self, token_type_ids, word_embeddings):
        """construct of EmbeddingPostprocessor"""
@@ -275,7 +275,7 @@ class BertOutput(nn.Cell):
                              batch_size=batch_size).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()
@@ -522,7 +522,7 @@ class BertAttention(nn.Cell):
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
-            self.add = P.TensorAdd()
+            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:


@@ -35,7 +35,7 @@ class LengthPenalty(nn.Cell):
    def __init__(self, weight=1.0, compute_type=mstype.float32):
        super(LengthPenalty, self).__init__()
        self.weight = weight
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.pow = P.Pow()
        self.div = P.RealDiv()
        self.five = Tensor(5.0, mstype.float32)
@@ -188,7 +188,7 @@ class BeamSearchDecoder(nn.Cell):
        self.decoder = decoder
        self.is_using_while = is_using_while
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.expand = P.ExpandDims()
        self.reshape = P.Reshape()
        self.shape_flat = (-1,)
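
The primitives registered in LengthPenalty (add, pow, real-div, and the constant 5.0) are the ingredients of the GNMT-style length penalty lp(len) = ((5 + len) / 6) ** weight; that formula is an inference from the surrounding code, not something stated in this diff. A small sketch of the same computation with the renamed operator (weight fixed to 1.0 for illustration):

    import mindspore as ms
    from mindspore import Tensor
    from mindspore.ops import operations as P

    add = P.Add()      # formerly P.TensorAdd()
    power = P.Pow()
    div = P.RealDiv()

    five = Tensor(5.0, ms.float32)
    six = Tensor(6.0, ms.float32)
    weight = Tensor(1.0, ms.float32)

    def length_penalty(length):
        """GNMT length penalty: ((5 + length) / 6) ** weight."""
        return power(div(add(length, five), six), weight)

    print(length_penalty(Tensor(10.0, ms.float32)))  # 2.5 when weight == 1.0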


@@ -36,7 +36,7 @@ class LengthPenalty(nn.Cell):
        super(LengthPenalty, self).__init__()
        self.weight = weight
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.pow = P.Pow()
        self.div = P.RealDiv()
@@ -178,7 +178,7 @@ class BeamSearchDecoder(nn.Cell):
        self.decoder = decoder
-        self.add = P.TensorAdd()
+        self.add = P.Add()
        self.expand = P.ExpandDims()
        self.reshape = P.Reshape()
        self.shape_flat = (-1,)


@@ -138,7 +138,7 @@ class MultiHeadAttention(nn.Cell):
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
-            self.add = P.TensorAdd()
+            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()

Some files were not shown because too many files have changed in this diff.