forked from mindspore-Ecosystem/mindspore
[feat][assistant] add new operator LayerNormGradGrad
parent 1537eeb647
commit 920dfdae9c

@@ -0,0 +1,259 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <map>
#include <memory>
#include <string>
#include <functional>
#include <numeric>
#include <vector>
#include "plugin/device/cpu/kernel/layer_norm_grad_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputSize = 8;
constexpr size_t kOutputSize = 3;
constexpr size_t kIdx0 = 0;
constexpr size_t kIdx3 = 3;
constexpr size_t kIdx4 = 4;
constexpr size_t kZero = 0;
constexpr size_t kMemMaxLen = 1e8;
}  // namespace

void LayerNormGradGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIdx0);
  mean_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIdx3);
  g_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIdx4);
}

// Computes inv_std[i] = 1 / sqrt(input_var[i]); fails if any variance is non-positive.
template <typename DATA_T>
bool calc_inv_std(DATA_T *input_var, DATA_T *inv_std, size_t mean_num) {
  for (size_t i = 0; i < mean_num; i++) {
    if (input_var[i] <= DATA_T(0)) {
      return false;
    }
    inv_std[i] = DATA_T(1) / sqrt(input_var[i]);
  }
  return true;
}

// First pass: computes x_hat and dy * gamma per element and accumulates the per-group means sum1..sum4.
template <typename DATA_T>
bool shard_inner_mean(size_t start_idx, size_t end_idx, size_t g_num, DATA_T *sum1, DATA_T *sum2, DATA_T *sum3,
                      DATA_T *sum4, DATA_T *inv_std, DATA_T *input_d_dx, DATA_T *input_dy, DATA_T *input_gamma,
                      DATA_T *input_x, DATA_T *input_mean, DATA_T *x_hat, DATA_T *dy_gamma) {
  for (size_t i = start_idx; i < end_idx; i++) {
    if (g_num == 0) {
      return false;
    }
    size_t sum_idx = i / g_num;
    sum1[sum_idx] -= inv_std[sum_idx] * input_d_dx[i] / static_cast<DATA_T>(g_num);
    DATA_T cur_x_hat = (input_x[i] - input_mean[sum_idx]) * inv_std[sum_idx];
    x_hat[i] = cur_x_hat;
    sum2[sum_idx] -= cur_x_hat * inv_std[sum_idx] * input_d_dx[i] / static_cast<DATA_T>(g_num);
    size_t g_idx = i % g_num;
    DATA_T cur_dy_gamma = input_dy[i] * input_gamma[g_idx];
    dy_gamma[i] = cur_dy_gamma;
    sum3[sum_idx] += cur_dy_gamma / static_cast<DATA_T>(g_num);
    sum4[sum_idx] += cur_dy_gamma * cur_x_hat / static_cast<DATA_T>(g_num);
  }
  return true;
}

// Second pass: accumulates the per-group means sum5..sum7 and the per-element term part3 from the first-pass results.
template <typename DATA_T>
bool shard_outer_mean(size_t start_idx, size_t end_idx, size_t g_num, DATA_T *sum2, DATA_T *sum3, DATA_T *sum4,
                      DATA_T *sum5, DATA_T *sum6, DATA_T *sum7, DATA_T *part3, DATA_T *inv_std, DATA_T *input_d_dx,
                      DATA_T *input_d_dg, DATA_T *x_hat, DATA_T *dy_gamma, DATA_T *input_dy, DATA_T *input_x,
                      DATA_T *input_mean) {
  for (size_t i = start_idx; i < end_idx; i++) {
    if (g_num == 0) {
      return false;
    }
    size_t g_idx = i % g_num;
    size_t sum_idx = i / g_num;
    DATA_T part_sum1 = dy_gamma[i] - sum3[sum_idx] - x_hat[i] * sum4[sum_idx];
    DATA_T part_sum2 =
      dy_gamma[i] * sum2[sum_idx] - sum4[sum_idx] * input_d_dx[i] * inv_std[sum_idx] + input_dy[i] * input_d_dg[g_idx];
    sum5[sum_idx] += input_d_dx[i] * part_sum1 / static_cast<DATA_T>(g_num);
    sum6[sum_idx] += (input_x[i] - input_mean[sum_idx]) * part_sum2 / static_cast<DATA_T>(g_num);
    DATA_T cur_part3 = inv_std[sum_idx] * part_sum2;
    part3[i] = cur_part3;
    sum7[sum_idx] -= cur_part3 / static_cast<DATA_T>(g_num);
  }
  return true;
}

// Assembles the outputs with respect to x (sopd_x) and to dy (sopd_dy).
template <typename DATA_T>
bool shard_input_prop(size_t start_idx, size_t end_idx, size_t g_num, DATA_T *sum1, DATA_T *sum2, DATA_T *sum5,
                      DATA_T *sum6, DATA_T *sum7, DATA_T *part3, DATA_T *inv_std, DATA_T *input_d_dx,
                      DATA_T *input_gamma, DATA_T *input_d_dg, DATA_T *input_d_db, DATA_T *x_hat, DATA_T *output_sopd_x,
                      DATA_T *output_sopd_dy) {
  for (size_t i = start_idx; i < end_idx; i++) {
    if (g_num == 0) {
      return false;
    }
    size_t g_idx = i % g_num;
    size_t sum_idx = i / g_num;
    DATA_T cur_part4 = -x_hat[i] * inv_std[sum_idx] * inv_std[sum_idx] * (sum5[sum_idx] + sum6[sum_idx]);
    output_sopd_x[i] = part3[i] + cur_part4 + sum7[sum_idx];
    DATA_T cur_part5 = input_gamma[g_idx] * input_d_dx[i] * inv_std[sum_idx];
    DATA_T cur_part6 = input_gamma[g_idx] * sum1[sum_idx];
    DATA_T cur_part7 = input_gamma[g_idx] * x_hat[i] * sum2[sum_idx];
    DATA_T cur_part8 = x_hat[i] * input_d_dg[g_idx];
    output_sopd_dy[i] = cur_part5 + cur_part6 + cur_part7 + cur_part8 + input_d_db[g_idx];
  }
  return true;
}

// Accumulates the output with respect to gamma (sopd_g) across the normalized groups.
template <typename DATA_T>
bool shard_param_prop(size_t start_idx, size_t end_idx, size_t g_num, DATA_T *sum1, DATA_T *sum2, DATA_T *inv_std,
                      DATA_T *input_d_dx, DATA_T *x_hat, DATA_T *input_dy, DATA_T *output_sopd_g) {
  for (size_t i = start_idx; i < end_idx; i++) {
    if (g_num == 0) {
      return false;
    }
    size_t g_idx = i % g_num;
    size_t sum_idx = i / g_num;
    DATA_T cur_part9 = input_dy[i] * x_hat[i] * sum2[sum_idx];
    DATA_T cur_part10 = input_dy[i] * sum1[sum_idx];
    DATA_T cur_part11 = input_dy[i] * input_d_dx[i] * inv_std[sum_idx];
    output_sopd_g[g_idx] += cur_part9 + cur_part10 + cur_part11;
  }
  return true;
}

bool LayerNormGradGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                           const std::vector<AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputSize, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputSize, kernel_name_);
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'input_x' must be float16 or float32, but got "
                      << dtype_;
  }
  return true;
}

template <typename DATA_T>
void LayerNormGradGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                 const std::vector<AddressPtr> &outputs) {
  // enter LayerNormGradGradCompute
  auto input_x = reinterpret_cast<DATA_T *>(inputs[0]->addr);
  MS_EXCEPTION_IF_NULL(input_x);
  auto input_dy = reinterpret_cast<DATA_T *>(inputs[1]->addr);
  MS_EXCEPTION_IF_NULL(input_dy);
  auto input_var = reinterpret_cast<DATA_T *>(inputs[2]->addr);
  MS_EXCEPTION_IF_NULL(input_var);
  auto input_mean = reinterpret_cast<DATA_T *>(inputs[3]->addr);
  MS_EXCEPTION_IF_NULL(input_mean);
  auto input_gamma = reinterpret_cast<DATA_T *>(inputs[4]->addr);
  MS_EXCEPTION_IF_NULL(input_gamma);
  auto input_d_dx = reinterpret_cast<DATA_T *>(inputs[5]->addr);
  MS_EXCEPTION_IF_NULL(input_d_dx);
  auto input_d_dg = reinterpret_cast<DATA_T *>(inputs[6]->addr);
  MS_EXCEPTION_IF_NULL(input_d_dg);
  auto input_d_db = reinterpret_cast<DATA_T *>(inputs[7]->addr);
  MS_EXCEPTION_IF_NULL(input_d_db);
  auto output_sopd_x = reinterpret_cast<DATA_T *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(output_sopd_x);
  auto output_sopd_dy = reinterpret_cast<DATA_T *>(outputs[1]->addr);
  MS_EXCEPTION_IF_NULL(output_sopd_dy);
  auto output_sopd_g = reinterpret_cast<DATA_T *>(outputs[2]->addr);
  MS_EXCEPTION_IF_NULL(output_sopd_g);
  size_t num =
    static_cast<size_t>(std::accumulate(input_shape_.cbegin(), input_shape_.cend(), 1, std::multiplies<int64_t>{}));
  size_t g_num =
    static_cast<size_t>(std::accumulate(g_shape_.cbegin(), g_shape_.cend(), 1, std::multiplies<int64_t>{}));
  size_t mean_num =
    static_cast<size_t>(std::accumulate(mean_shape_.cbegin(), mean_shape_.cend(), 1, std::multiplies<int64_t>{}));
  if (num == 0 || num > kMemMaxLen || mean_num == 0 || mean_num > kMemMaxLen) {
    MS_EXCEPTION(ValueError) << "For LayerNormGradGrad, memory allocation failed: the element counts of the input "
                             << "and the mean must be in (0, " << kMemMaxLen << "].";
  }
  auto inv_std = std::make_unique<DATA_T[]>(mean_num);
  if (!calc_inv_std<DATA_T>(input_var, inv_std.get(), mean_num)) {
    MS_EXCEPTION(ValueError) << "For LayerNormGradGrad, variance must be positive.";
  }
  auto x_hat = std::make_unique<DATA_T[]>(num);
  auto dy_gamma = std::make_unique<DATA_T[]>(num);
  auto sum1 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum1.get(), mean_num, DATA_T(0));
  auto sum2 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum2.get(), mean_num, DATA_T(0));
  auto sum3 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum3.get(), mean_num, DATA_T(0));
  auto sum4 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum4.get(), mean_num, DATA_T(0));
  shard_inner_mean<DATA_T>(0, num, g_num, sum1.get(), sum2.get(), sum3.get(), sum4.get(), inv_std.get(), input_d_dx,
                           input_dy, input_gamma, input_x, input_mean, x_hat.get(), dy_gamma.get());
  auto sum5 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum5.get(), mean_num, DATA_T(0));
  auto sum6 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum6.get(), mean_num, DATA_T(0));
  auto sum7 = std::make_unique<DATA_T[]>(mean_num);
  std::fill_n(sum7.get(), mean_num, DATA_T(0));
  auto part3 = std::make_unique<DATA_T[]>(num);
  shard_outer_mean<DATA_T>(0, num, g_num, sum2.get(), sum3.get(), sum4.get(), sum5.get(), sum6.get(), sum7.get(),
                           part3.get(), inv_std.get(), input_d_dx, input_d_dg, x_hat.get(), dy_gamma.get(), input_dy,
                           input_x, input_mean);
  shard_input_prop<DATA_T>(0, num, g_num, sum1.get(), sum2.get(), sum5.get(), sum6.get(), sum7.get(), part3.get(),
                           inv_std.get(), input_d_dx, input_gamma, input_d_dg, input_d_db, x_hat.get(), output_sopd_x,
                           output_sopd_dy);
  std::fill_n(output_sopd_g, g_num, DATA_T(0));
  shard_param_prop<DATA_T>(0, num, g_num, sum1.get(), sum2.get(), inv_std.get(), input_d_dx, x_hat.get(), input_dy,
                           output_sopd_g);
}

std::vector<KernelAttr> LayerNormGradGradCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list = {
    KernelAttr()
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeFloat32)
      .AddOutputAttr(kNumberTypeFloat32)
      .AddOutputAttr(kNumberTypeFloat32)
      .AddOutputAttr(kNumberTypeFloat32),
    KernelAttr()
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeFloat16)
      .AddOutputAttr(kNumberTypeFloat16)
      .AddOutputAttr(kNumberTypeFloat16)
      .AddOutputAttr(kNumberTypeFloat16),
  };

  return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, LayerNormGradGrad, LayerNormGradGradCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
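For readability, here is a minimal NumPy restatement of the per-group intermediates that calc_inv_std and shard_inner_mean compute in the kernel above. It is a sketch only: it assumes the flattened data is laid out as one normalized group per row with g_num elements per group (the 2-D case used by the test entry added later in this commit), and the helper name and shapes are illustrative, not part of the kernel.

import numpy as np

def inner_mean_reference(x, dy, variance, mean, gamma, d_dx):
    # inv_std, x_hat and dy_gamma mirror calc_inv_std / shard_inner_mean.
    inv_std = 1.0 / np.sqrt(variance)                  # shape (rows,)
    x_hat = (x - mean[:, None]) * inv_std[:, None]     # normalized input
    dy_gamma = dy * gamma[None, :]                     # upstream grad scaled by gamma
    # sum1..sum4 are per-row means, matching the "/ g_num" accumulation in the loop.
    sum1 = -np.mean(inv_std[:, None] * d_dx, axis=1)
    sum2 = -np.mean(x_hat * inv_std[:, None] * d_dx, axis=1)
    sum3 = np.mean(dy_gamma, axis=1)
    sum4 = np.mean(dy_gamma * x_hat, axis=1)
    return inv_std, x_hat, dy_gamma, sum1, sum2, sum3, sum4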
@@ -0,0 +1,52 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYERNORMGRADGRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYERNORMGRADGRAD_CPU_KERNEL_H_
#include <functional>
#include <memory>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class LayerNormGradGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  LayerNormGradGradCpuKernelMod() = default;
  ~LayerNormGradGradCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename DATA_T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  TypeId dtype_{kTypeUnknown};
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> g_shape_;
  std::vector<int64_t> mean_shape_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYERNORMGRADGRAD_CPU_KERNEL_H_
@@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LayerNormGradGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

layernorm_grad_grad_op_info = AiCPURegOp("LayerNormGradGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "dy", "required") \
    .input(2, "variance", "required") \
    .input(3, "mean", "required") \
    .input(4, "gamma", "required") \
    .input(5, "d_dx", "required") \
    .input(6, "d_dg", "required") \
    .input(7, "d_db", "required") \
    .output(0, "sopd_x", "required") \
    .output(1, "sopd_dy", "required") \
    .output(2, "sopd_gamma", "required") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(layernorm_grad_grad_op_info)
def _layernorm_grad_grad_aicpu():
    """LayerNormGradGrad aicpu register"""
    return
@@ -1395,6 +1395,8 @@ class LayerNormGradGrad(Primitive):
        """init"""
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
        self.init_prim_io_names(inputs=['x', 'dy', 'variance', 'mean', 'gamma', 'd_dx', 'd_dg', 'd_db'],
                                outputs=['sopd_x', 'sopd_dy', 'sopd_gamma'])


class LogSoftmaxGrad(Primitive):
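A short usage sketch of the primitive defined above, using the input order from init_prim_io_names and the shapes from the test entry added later in this commit. The import path for G is assumed to be mindspore.ops.operations._grad_ops, as in the existing operator tests, and the random data is illustrative only.

import numpy as np
import mindspore as ms
from mindspore.ops.operations import _grad_ops as G

op = G.LayerNormGradGrad()
x = ms.Tensor(np.random.randn(2, 16).astype(np.float32))         # original LayerNorm input
dy = ms.Tensor(np.random.randn(2, 16).astype(np.float32))        # gradient w.r.t. the LayerNorm output
variance = ms.Tensor((np.abs(np.random.randn(2)) + 1.0).astype(np.float32))  # per-row variance, must be positive
mean = ms.Tensor(np.random.randn(2).astype(np.float32))          # per-row mean
gamma = ms.Tensor(np.random.randn(16).astype(np.float32))
d_dx = ms.Tensor(np.random.randn(2, 16).astype(np.float32))      # incoming grads of the first-order gradients
d_dg = ms.Tensor(np.random.randn(16).astype(np.float32))
d_db = ms.Tensor(np.random.randn(16).astype(np.float32))
sopd_x, sopd_dy, sopd_gamma = op(x, dy, variance, mean, gamma, d_dx, d_dg, d_db)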
@@ -228,6 +228,10 @@ class InputOpNet(nn.Cell):
        x = self.op(x1, x2, x3, x4, x5, x6, x7)
        return x

    def construct8_c0(self, x1, x2, x3, x4, x5, x6, x7, x8):
        x = self.op(x1, x2, x3, x4, x5, x6, x7, x8)
        return x

    def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
        x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9)
        return x
@@ -2705,6 +2705,11 @@ test_case_nn_ops = [
        'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
        'desc_bprop': [[2, 16], [16], [16]],
        'skip': ['backward']}),
    ('LayerNormGradGrad', {
        'block': G.LayerNormGradGrad(),
        'desc_inputs': [[2, 16], [2, 16], [2], [2], [16], [2, 16], [16], [16]],
        'desc_bprop': [[2, 16], [2, 16], [16]],
        'skip': ['backward']}),
    ('BatchNorm', {
        'block': P.BatchNorm(),
        'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
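In the LayerNormGradGrad entry above, the eight desc_inputs follow the primitive's input order x, dy, variance, mean, gamma, d_dx, d_dg and d_db: [2, 16] for the input-shaped tensors, [2] for the per-row mean and variance, and [16] for the gamma-shaped tensors; the three desc_bprop shapes correspond to the outputs sopd_x, sopd_dy and sopd_gamma.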