!44864 acc clip norm

Merge pull request !44864 from fangzehua/acc_clip_norm
i-robot 2022-11-01 01:15:42 +00:00 committed by Gitee
commit 91a99f5f00
4 changed files with 89 additions and 52 deletions
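
This merge request reworks ClipByNorm handling in two ways. First, the ClipByNormFission pass moves out of the common backend optimization pipeline and into the Ascend IR fusion pipeline, so only Ascend decomposes the op. Second, the CPU ClipByNorm kernel is accelerated: the TransposeIterator and BroadcastIterator positions consumed by the L2Norm, Div, and Mul stages are now precomputed once in InitSizeLists() and cached as index tables (l2_norm_index_, div_index1_/div_index2_, mul_index1_/mul_index2_), turning the hot parallel loops into flat table lookups; the Div and Mul launches also each get their own ParallelSearchInfo.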

View File

@@ -70,7 +70,6 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
   common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
   common_pm->AddPass(std::make_shared<ConvertUnusedTupleParaToMakeTuple>());
   common_pm->AddPass(std::make_shared<ConvertConstScalarToTensor>());
-  common_pm->AddPass(std::make_shared<ClipByNormFission>());
   common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
   common_pm->AddPass(std::make_shared<AddTrainingAttr>());
   common_pm->AddPass(std::make_shared<FlattenConcatFission>());

View File

@@ -370,6 +370,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<AdamWeightDecayFission>());
   ir_fusion_pm->AddPass(std::make_shared<ScaleGradFission>());
+  ir_fusion_pm->AddPass(std::make_shared<ClipByNormFission>());
   ir_fusion_pm->AddPass(std::make_shared<LambFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());

View File

@@ -197,6 +197,73 @@ void ClipByNormCpuKernelMod::InitSizeLists() {
   auto output_size = float_type_size * SizeOf(output_shape_);
   output_size = std::max(output_size, float_type_size);
   (void)output_size_list_.emplace_back(output_size);
+  // Calculate l2norm iter
+  // Calculate transpose axes and stride
+  size_t j = 0;
+  size_t k = 0;
+  stride_ = 1;
+  size_t axis_size = axis_.size();
+  std::vector<size_t> axes(x_shape_.size());
+  for (size_t i = 0; i < x_dim_; ++i) {
+    if (j == axis_size || i != axis_[j]) {
+      axes[k] = i;
+      ++k;
+    } else {
+      stride_ *= LongToSize(x_shape_[i]);
+      ++j;
+    }
+  }
+  for (const auto &v : axis_) {
+    axes[k] = v;
+    ++k;
+  }
+  // Calculate transpose shape
+  ShapeVector transpose_shape(x_shape_.size());
+  for (size_t i = 0; i < x_dim_; ++i) {
+    transpose_shape[i] = x_shape_[axes[i]];
+  }
+  auto l2_norm_out_ele = l2_norm_out_size / sizeof(float);
+  l2_norm_index_.clear();
+  l2_norm_index_.resize(l2_norm_out_ele * stride_);
+  TransposeIterator l2n_base_iter(std::move(transpose_shape), std::move(axes), x_shape_);
+  l2n_base_iter.SetPos(0);
+  for (size_t i = 0; i < l2_norm_index_.size(); ++i) {
+    l2_norm_index_[i] = l2n_base_iter.GetPos();
+    l2n_base_iter.GenNextPos();
+  }
+  // Calculate Div iter
+  if (!x_shape_.empty()) {
+    BroadcastIterator div_base_iter(x_shape_, l2_norm_output_shape_, x_shape_);
+    div_base_iter.SetPos(0);
+    div_index1_.clear();
+    div_index2_.clear();
+    auto div_ele = x_size / x_type_size;
+    div_index1_.resize(div_ele);
+    div_index2_.resize(div_ele);
+    for (size_t i = 0; i < div_ele; ++i) {
+      div_index1_[i] = div_base_iter.GetInputPosA();
+      div_index2_[i] = div_base_iter.GetInputPosB();
+      div_base_iter.GenNextPos();
+    }
+  }
+  // Calculate Mul iter
+  if (!x_shape_.empty()) {
+    BroadcastIterator mul_base_iter(x_shape_, clip_norm_shape_, output_shape_);
+    mul_base_iter.SetPos(0);
+    mul_index1_.clear();
+    mul_index2_.clear();
+    auto output_ele = output_size / sizeof(float);
+    mul_index1_.resize(output_ele);
+    mul_index2_.resize(output_ele);
+    for (size_t i = 0; i < output_ele; ++i) {
+      mul_index1_[i] = mul_base_iter.GetInputPosA();
+      mul_index2_[i] = mul_base_iter.GetInputPosB();
+      mul_base_iter.GenNextPos();
+    }
+  }
 }
 template <typename T, typename S>
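
The new InitSizeLists() logic above trades a little memory for launch-time speed: the TransposeIterator is stepped through every element once at init, and its flattened source offsets are cached in l2_norm_index_; each L2Norm launch (next hunk) then reduces over contiguous slices of that table instead of stepping a per-thread iterator. Below is a minimal standalone sketch of the same pattern for a row-major 2-D tensor reduced over axis 0; all names in it are illustrative, not MindSpore APIs.

// sketch_transpose_index.cc -- precompute transpose offsets once, reuse per launch.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Source offsets of a row-major [rows, cols] buffer visited column-by-column,
// i.e. the traversal order needed to reduce over axis 0.
std::vector<std::size_t> PrecomputeTransposeIndex(std::size_t rows, std::size_t cols) {
  std::vector<std::size_t> index(rows * cols);
  std::size_t k = 0;
  for (std::size_t c = 0; c < cols; ++c) {
    for (std::size_t r = 0; r < rows; ++r) {
      index[k++] = r * cols + c;  // offset in the untransposed buffer
    }
  }
  return index;
}

int main() {
  const std::size_t rows = 2, cols = 3;  // reducing axis 0 leaves 3 outputs
  const std::vector<float> x = {1, 2, 3, 4, 5, 6};
  const auto idx = PrecomputeTransposeIndex(rows, cols);
  const std::size_t stride = rows;  // elements reduced into each output

  // Launch-time loop: output i reads the slice idx[i*stride .. (i+1)*stride).
  for (std::size_t i = 0; i < cols; ++i) {
    float acc = 0.0f;
    for (std::size_t j = 0; j < stride; ++j) {
      const float v = x[idx[i * stride + j]];
      acc += v * v;
    }
    std::cout << "l2norm[" << i << "] = " << std::sqrt(acc) << '\n';
  }
  return 0;
}

Output element i owns the slice idx[i * stride .. (i + 1) * stride), which is exactly how the corrected L2NormLaunch below indexes l2_norm_index_.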
@@ -220,46 +287,17 @@ void ClipByNormCpuKernelMod::LaunchFunc(const std::vector<AddressPtr> &inputs, c
 template <typename T>
 void ClipByNormCpuKernelMod::L2NormLaunch(const T *x_addr, float *l2_norm_output_addr, size_t l2_norm_output_size) {
-  // Calculate transpose axes and stride
-  size_t j = 0;
-  size_t k = 0;
-  size_t stride = 1;
-  size_t axis_size = axis_.size();
-  std::vector<size_t> axes(x_shape_.size());
-  for (size_t i = 0; i < x_dim_; ++i) {
-    if (j == axis_size || i != axis_[j]) {
-      axes[k] = i;
-      ++k;
-    } else {
-      stride *= LongToSize(x_shape_[i]);
-      ++j;
-    }
-  }
-  for (const auto &v : axis_) {
-    axes[k] = v;
-    ++k;
-  }
-  // Calculate transpose shape
-  ShapeVector transpose_shape(x_shape_.size());
-  for (size_t i = 0; i < x_dim_; ++i) {
-    transpose_shape[i] = x_shape_[axes[i]];
-  }
   // Run `l2_norm(x)` calculation
-  TransposeIterator base_iter(std::move(transpose_shape), std::move(axes), x_shape_);
-  auto task = [this, &base_iter, &x_addr, &l2_norm_output_addr, &stride](size_t start, size_t end) {
-    auto iter = base_iter;
+  auto task = [&](size_t start, size_t end) {
     float zero = static_cast<float>(0);
     float temp = zero;
     float denominator = zero;
-    iter.SetPos(start * stride);
     for (size_t i = start; i < end; ++i) {
-      denominator = static_cast<float>(x_addr[iter.GetPos()]);
+      denominator = static_cast<float>(x_addr[l2_norm_index_[i * stride_]]);
       denominator = denominator * denominator;
-      iter.GenNextPos();
-      for (size_t j = 1; j < stride; ++j) {
-        temp = static_cast<float>(x_addr[iter.GetPos()]);
+      for (size_t j = 1; j < stride_; ++j) {
+        temp = static_cast<float>(x_addr[l2_norm_index_[i * stride_ + j]]);
         denominator += (temp * temp);
-        iter.GenNextPos();
       }
       denominator = (denominator > epsilon_) ? denominator : epsilon_;
       l2_norm_output_addr[i] = sqrt(denominator);
@@ -276,15 +314,11 @@ void ClipByNormCpuKernelMod::DivLaunch(const T *x_addr, const float *l2_norm_out
     div_output_addr[0] = static_cast<float>(x_addr[0]) / l2_norm_output_addr[0];
     return;
   }
-  BroadcastIterator broadcast_base_iter(x_shape_, l2_norm_output_shape_, x_shape_);
-  auto task = [this, &broadcast_base_iter, &x_addr, &l2_norm_output_addr, &div_output_addr](size_t start, size_t end) {
-    auto iter = broadcast_base_iter;
-    iter.SetPos(start);
+  auto task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
       float zero = static_cast<float>(0);
-      float dividend = static_cast<float>(x_addr[iter.GetInputPosA()]);
-      float divisor = l2_norm_output_addr[iter.GetInputPosB()];
-      iter.GenNextPos();
+      float dividend = static_cast<float>(x_addr[div_index1_[i]]);
+      float divisor = l2_norm_output_addr[div_index2_[i]];
       if (divisor == zero) {
         if (dividend == zero) {
           div_output_addr[i] = std::numeric_limits<float>::quiet_NaN();
@@ -301,7 +335,7 @@ void ClipByNormCpuKernelMod::DivLaunch(const T *x_addr, const float *l2_norm_out
       div_output_addr[i] = dividend / divisor;
     }
   };
-  ParallelLaunchAutoSearch(task, div_output_size / sizeof(float), this, &parallel_search_info_);
+  ParallelLaunchAutoSearch(task, div_output_size / sizeof(float), this, &parallel_search_info_div_);
 }
 template <typename T, typename S>
@@ -317,25 +351,20 @@ void ClipByNormCpuKernelMod::ClipNormMulAndCmpLaunch(const T *x_addr, const floa
     }
     return;
   }
-  BroadcastIterator broadcast_base_iter(x_shape_, clip_norm_shape_, output_shape_);
-  auto task = [this, &broadcast_base_iter, &x_addr, &clip_norm_addr, &div_output_addr, &output_addr](size_t start,
-                                                                                                     size_t end) {
-    auto iter = broadcast_base_iter;
-    iter.SetPos(start);
+  auto task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
-      float div_out = div_output_addr[iter.GetInputPosA()];
-      float clip_norm = static_cast<float>(clip_norm_addr[iter.GetInputPosB()]);
+      float div_out = div_output_addr[mul_index1_[i]];
+      float clip_norm = static_cast<float>(clip_norm_addr[mul_index2_[i]]);
       float mul_output = clip_norm * div_out;
-      float x = static_cast<float>(x_addr[iter.GetInputPosA()]);
+      float x = static_cast<float>(x_addr[mul_index1_[i]]);
       if (x * mul_output >= 0) {
         output_addr[i] = (mul_output * mul_output) > (x * x) ? x : mul_output;
       } else {
         output_addr[i] = mul_output;
       }
-      iter.GenNextPos();
     }
   };
-  ParallelLaunchAutoSearch(task, output_size / sizeof(float), this, &parallel_search_info_);
+  ParallelLaunchAutoSearch(task, output_size / sizeof(float), this, &parallel_search_info_mul_);
 }
 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ClipByNorm, ClipByNormCpuKernelMod);
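
DivLaunch and ClipNormMulAndCmpLaunch apply the same caching to BroadcastIterator: the positions of both broadcast operands are stored up front for every output element (div_index1_/div_index2_, mul_index1_/mul_index2_), so the worker threads index straight into the inputs with no per-thread iterator copy and no GenNextPos calls in the hot loop. A standalone sketch of the cached-broadcast-positions idea, assuming a simple [n, 1] operand broadcast against [n, m] (pos_a/pos_b are hypothetical names, not the MindSpore iterator API):

// sketch_broadcast_index.cc -- cache both operands' positions per output element.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t n = 2, m = 3;
  const std::vector<float> a = {2, 4, 6, 3, 6, 9};  // shape [2, 3]
  const std::vector<float> b = {2, 3};              // shape [2, 1], broadcast over m

  // Precompute once (the analogue of filling div_index1_/div_index2_).
  std::vector<std::size_t> pos_a(n * m), pos_b(n * m);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      pos_a[i * m + j] = i * m + j;  // a is taken element-wise
      pos_b[i * m + j] = i;          // b repeats along the broadcast axis
    }
  }

  // Launch-time loop: plain lookups, trivially divisible among threads.
  for (std::size_t i = 0; i < n * m; ++i) {
    std::cout << a[pos_a[i]] / b[pos_b[i]] << ' ';  // prints: 1 2 3 1 2 3
  }
  std::cout << '\n';
  return 0;
}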

View File

@@ -71,6 +71,14 @@ class ClipByNormCpuKernelMod : public NativeCpuKernelMod {
   ShapeVector clip_norm_shape_;
   ShapeVector l2_norm_output_shape_;
   ShapeVector output_shape_;
+  size_t stride_ = 1;
+  std::vector<size_t> l2_norm_index_;
+  std::vector<size_t> div_index1_;
+  std::vector<size_t> div_index2_;
+  std::vector<size_t> mul_index1_;
+  std::vector<size_t> mul_index2_;
+  ParallelSearchInfo parallel_search_info_div_;
+  ParallelSearchInfo parallel_search_info_mul_;
 };
 }  // namespace kernel
 }  // namespace mindspore
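
A likely motivation for the two new ParallelSearchInfo members: ParallelLaunchAutoSearch tunes its parallel block size per search-info object, and the kernel previously funneled the L2Norm, Div, and Mul workloads, which have different element counts and per-element costs, through the single inherited parallel_search_info_. With parallel_search_info_div_ and parallel_search_info_mul_, each launch can converge on its own tuning, while the L2Norm launch presumably keeps the inherited state.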