forked from mindspore-Ecosystem/mindspore
!17036 clean code cpu
From: @wanyiming Reviewed-by: @wuxuejian Signed-off-by: @wuxuejian
commit b75b9e8d9c
@@ -59,15 +59,15 @@ bool CTCLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
 }

 template <typename T>
-inline T LogSumExp(T logprob1, T logprob2) {
+inline T LogSumExp(const T logprob1, const T logprob2) {
   T kLogZero_ = -std::numeric_limits<T>::infinity();
   if (logprob1 == kLogZero_) {
     return logprob2;
   } else if (logprob2 == kLogZero_) {
     return logprob1;
   } else {
-    return (logprob1 > logprob2) ? logprob1 + log1p(exp(logprob2 - logprob1))
-                                 : logprob2 + log1p(exp(logprob1 - logprob2));
+    return (logprob1 > logprob2) ? logprob1 + static_cast<T>(log1p(exp(logprob2 - logprob1)))
+                                 : logprob2 + static_cast<T>(log1p(exp(logprob1 - logprob2)));
   }
 }
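Side note (not part of the patch): the expression above is the usual numerically stable log-add, anchored on the larger argument so exp only ever sees a non-positive value; the added static_cast<T> appears to just make the double-to-T narrowing explicit, since log1p/exp here return double. A minimal stand-alone sketch with illustrative names:

```cpp
// Stable log(exp(a) + exp(b)) computed as max + log1p(exp(min - max)),
// so exp is only ever called on a non-positive argument.
#include <cmath>
#include <cstdio>
#include <limits>

template <typename T>
T StableLogAdd(T a, T b) {
  const T kLogZero = -std::numeric_limits<T>::infinity();
  if (a == kLogZero) return b;
  if (b == kLogZero) return a;
  return (a > b) ? a + static_cast<T>(std::log1p(std::exp(b - a)))
                 : b + static_cast<T>(std::log1p(std::exp(a - b)));
}

int main() {
  double naive = std::log(std::exp(1000.0) + std::exp(999.0));  // exp(1000) overflows -> inf
  double stable = StableLogAdd(1000.0, 999.0);                  // ~1000.313
  std::printf("naive=%f stable=%f\n", naive, stable);
  return 0;
}
```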
@@ -79,10 +79,10 @@ void CTCLossCPUKernel::CalculateFwdVar(const std::vector<uint32_t> &label_with_b
   int T = (*log_alpha_b)[0].size();
   TT kLogZero_ = -std::numeric_limits<TT>::infinity();

-  (*log_alpha_b)[0][0] = log(y[blank_index_][0]);
+  (*log_alpha_b)[0][0] = static_cast<TT>(log(y[blank_index_][0]));
   auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_;
   if (label_with_blank.size() > 1) {
-    (*log_alpha_b)[1][0] = log(y[label_0][0]);
+    (*log_alpha_b)[1][0] = static_cast<TT>(log(y[label_0][0]));
   }

   for (int t = 1; t < T; ++t) {
@@ -105,7 +105,7 @@ void CTCLossCPUKernel::CalculateFwdVar(const std::vector<uint32_t> &label_with_b
         }
       }

-      (*log_alpha_b)[u][t] = log(y[label_with_blank[u]][t]) + sum_log_alpha_b;
+      (*log_alpha_b)[u][t] = static_cast<TT>(log(y[label_with_blank[u]][t])) + sum_log_alpha_b;
     }
   }
 }
@@ -176,7 +176,7 @@ void CTCLossCPUKernel::CalculateGrad(const std::vector<uint32_t> &label_with_bla
       prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]);
     }
     for (size_t l = 0; l < L; ++l) {
-      (*dy_b)[l][t] = y[l][t] - exp(prob_sum[l] - log_pzx);
+      (*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx));
     }
   }
 }
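For context (not part of the patch): prob_sum[l] accumulates via LogSumExp, in log space, the forward-backward mass passing through label l at time t, and the line being cast evaluates what is, under the usual CTC conventions, the familiar gradient form

$$ (dy)_{l,t} \;=\; y_{l,t} \;-\; \frac{1}{p(z \mid x)} \sum_{u:\ \mathrm{label}(u)=l} \alpha_t(u)\,\beta_t(u), $$

computed as y[l][t] - exp(prob_sum[l] - log_pzx), so the sum and the division by p(z|x) both stay in log space until the final exp.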
@@ -229,8 +229,9 @@ void InnerSoftMax(T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, co
     }

     for (size_t c = 0; c < num_class; ++c) {
-      sumCoeff += exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff);
-      (*softmax_probs)[c][t] = exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff);
+      sumCoeff += static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
+      (*softmax_probs)[c][t] =
+        static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
     }

     for (size_t c = 0; c < num_class; ++c) {
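Background (not part of the patch): the loop is a standard max-subtracted softmax over the classes of one (t, b) slice, so exp only receives non-positive arguments; as above, the cast only makes the narrowing explicit. A stand-alone sketch with illustrative names:

```cpp
// Max-subtracted softmax: finite results even for large logits.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> Softmax(const std::vector<float> &logits) {
  const float max_coeff = *std::max_element(logits.begin(), logits.end());
  float sum_coeff = 0.0f;
  std::vector<float> probs(logits.size());
  for (size_t c = 0; c < logits.size(); ++c) {
    probs[c] = std::exp(logits[c] - max_coeff);  // argument is <= 0
    sum_coeff += probs[c];
  }
  for (size_t c = 0; c < logits.size(); ++c) {
    probs[c] /= sum_coeff;
  }
  return probs;
}

int main() {
  for (float p : Softmax({1000.0f, 1001.0f, 1002.0f})) std::printf("%f ", p);  // ~0.090 0.245 0.665
  std::printf("\n");
  return 0;
}
```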
@@ -267,7 +268,7 @@ void CTCLossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const
   T kLogZero_ = -std::numeric_limits<T>::infinity();
   // check validation of sequence length
   for (size_t b = 0; b < batch_size_; ++b) {
-    if (sequence_length_addr[b] < uint32_t(0)) {
+    if (sequence_length_addr[b] == uint32_t(0)) {
       MS_LOG(EXCEPTION) << "Sequence length should > 0, but gets " << sequence_length_addr[b];
     }
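Note (not part of the patch): this appears to be the only behavioural change among the CTC edits. sequence_length_addr[b] is unsigned, so the old comparison against uint32_t(0) with < could never be true and empty sequences slipped past the check; comparing with == actually raises the exception. A tiny illustration:

```cpp
// An unsigned value is never less than zero, so the old guard was dead code.
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t len = 0;
  if (len < uint32_t(0)) std::printf("never printed\n");    // always false
  if (len == uint32_t(0)) std::printf("length is zero\n");  // fires for an empty sequence
  return 0;
}
```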
@@ -120,21 +120,21 @@ void MirrorPadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
   }
   extract_paddings(paddings_arg, padd_dim, paddings);
   // Create anchor points for non mirrored data inside new tensor
-  int ap1_x = paddings[WIDTH + LEFT];
-  int ap2_x = paddings[WIDTH + LEFT] + old_width - 1;
-  int ap1_y = paddings[HEIGHT + TOP];
-  int ap2_y = paddings[HEIGHT + TOP] + old_height - 1;
-  int ap1_channel = paddings[CHANNEL + LEFT];
-  int ap2_channel = paddings[CHANNEL + LEFT] + old_channel - 1;
-  int ap1_batch = paddings[BATCH + LEFT];
-  int ap2_batch = paddings[BATCH + LEFT] + old_batch - 1;
-  int channels_new = old_channel + paddings[CHANNEL + LEFT] + paddings[CHANNEL + RIGHT];
+  int ap1_x = paddings[WIDTH];
+  int ap2_x = paddings[WIDTH] + old_width - 1;
+  int ap1_y = paddings[HEIGHT];
+  int ap2_y = paddings[HEIGHT] + old_height - 1;
+  int ap1_channel = paddings[CHANNEL];
+  int ap2_channel = paddings[CHANNEL] + old_channel - 1;
+  int ap1_batch = paddings[BATCH];
+  int ap2_batch = paddings[BATCH] + old_batch - 1;
+  int channels_new = old_channel + paddings[CHANNEL] + paddings[CHANNEL + RIGHT];

   for (size_t pos = 0; pos < output_size_; ++pos) {
-    int block_num = (pos / padded_width) / padded_height;
+    int block_num = (SizeToLong(pos) / padded_width) / padded_height;
     // cur position
-    const int padded_x = pos % padded_width;
-    const int padded_y = (pos / padded_width) % padded_height;
+    const int padded_x = SizeToLong(pos) % padded_width;
+    const int padded_y = (SizeToLong(pos) / padded_width) % padded_height;
     const int padded_channel = block_num % channels_new;
     const int padded_batch = block_num / channels_new;
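For context (not part of the patch): the index arithmetic in this loop decomposes the flat position in the padded NCHW output, with pos % padded_width giving x, (pos / padded_width) % padded_height giving y, and the remaining quotient splitting into channel and batch; the ap1_*/ap2_* anchors then bound where the unpadded data sits inside the padded tensor. Dropping "+ LEFT" and "+ TOP" presumably relies on those offsets being defined as zero in the flattened paddings layout. A small stand-alone sketch of the decomposition, with illustrative sizes:

```cpp
// Decompose a flat NCHW index into (batch, channel, y, x) and rebuild it,
// the same arithmetic the padded-output loop above uses.
#include <cassert>
#include <cstdio>

int main() {
  const int width = 5, height = 4, channels = 3, batches = 2;
  for (int pos = 0; pos < batches * channels * height * width; ++pos) {
    const int x = pos % width;
    const int y = (pos / width) % height;
    const int block = (pos / width) / height;  // fused (batch, channel) index
    const int channel = block % channels;
    const int batch = block / channels;
    assert(((batch * channels + channel) * height + y) * width + x == pos);
  }
  std::printf("round-trip ok\n");
  return 0;
}
```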
@@ -167,13 +167,13 @@ void MirrorPadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
     }

     // calculate equivalent block in input
-    equiv_block_num = ((matchval_batch_index - paddings[BATCH + LEFT]) * old_channel) +
-                      (matchval_channel_index - paddings[CHANNEL + LEFT]);
+    equiv_block_num =
+      ((matchval_batch_index - paddings[BATCH]) * old_channel) + (matchval_channel_index - paddings[CHANNEL]);

     // copy data from equiv block and adjusted x and y values in unpadded tensor
-    outputs_addr[pos] =
-      inputs_addr[(equiv_block_num * old_height + matchval_y_index - paddings[HEIGHT + TOP]) * old_width +
-                  matchval_x_index - paddings[WIDTH + LEFT]];
+    auto pos_index = (equiv_block_num * old_height + matchval_y_index - paddings[HEIGHT]) * old_width +
+                     matchval_x_index - paddings[WIDTH];
+    outputs_addr[pos] = inputs_addr[pos_index];
   }
 }
@@ -82,7 +82,7 @@ void MirrorPadGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     max_height = max_height + (2 * (max_height - 1));
   }

-  if (output_shape_[(output_shape_.size() - 2) + 0] > max_height ||
+  if (output_shape_[(output_shape_.size() - 2)] > max_height ||
       output_shape_[(output_shape_.size() - 2) + 1] > max_width) {
     MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more DIMS";
   }
@@ -136,24 +136,23 @@ void MirrorPadGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
 }

 template <typename T>
-void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, const T *dy, T *interim_dy,
-                                                        const int dx_batches, const int dx_channels,
-                                                        const int dx_height, const int dx_width, const int dy_height,
-                                                        const int dy_width, const int padd_dim,
-                                                        const int64_t *paddings_arg, int mode, T *dx) {
+void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, T *interim_dy, const int dx_height,
+                                                        const int dx_width, const int dy_height, const int dy_width,
+                                                        const int padd_dim, const int64_t *paddings_arg, int mode,
+                                                        T *dx) {
   int64_t paddings[MAX_PADDINGS * PADDING_SIZE];  // local and fixed size to keep in registers
   for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) {
     paddings[i] = 0;  // init all to 0
   }
   extract_paddings_(paddings_arg, padd_dim, paddings);
-  int ap1_x = paddings[WIDTH + LEFT];
-  int ap2_x = paddings[WIDTH + LEFT] + dx_width - 1;
-  int ap1_y = paddings[HEIGHT + TOP];
-  int ap2_y = paddings[HEIGHT + TOP] + dx_height - 1;
+  int ap1_x = paddings[WIDTH];
+  int ap2_x = paddings[WIDTH] + dx_width - 1;
+  int ap1_y = paddings[HEIGHT];
+  int ap2_y = paddings[HEIGHT] + dx_height - 1;
   for (size_t pos = 0; pos < size; ++pos) {
-    int dx_block_num = (pos / dx_width) / dx_height;
-    const int grad_x = (pos % dx_width) + paddings[WIDTH + LEFT];
-    const int grad_y = ((pos / dx_width) % dx_height) + paddings[HEIGHT + TOP];
+    int dx_block_num = (SizeToLong(pos) / dx_width) / dx_height;
+    const int grad_x = (SizeToLong(pos) % dx_width) + paddings[WIDTH];
+    const int grad_y = ((SizeToLong(pos) / dx_width) % dx_height) + paddings[HEIGHT];
     dx[pos] = interim_dy[(dx_block_num * dy_height + grad_y) * dy_width + grad_x];
     int x_dist_1 = (ap1_x - grad_x - mode);
     int y_dist_1 = (ap1_y - grad_y - mode);
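Aside (not part of the patch): the other recurring change here wraps the loop index in SizeToLong, the framework's size_t-to-int64_t helper, before the divisions. Presumably this keeps the arithmetic signed rather than letting the int dimensions be promoted to unsigned, which is the kind of sign-conversion issue a clean-code pass targets. A minimal sketch of the idea, with a stand-in for the helper:

```cpp
// Converting the size_t index to int64_t first keeps the whole expression in
// signed arithmetic; otherwise the int dimensions would be promoted to unsigned.
#include <cstdint>
#include <cstdio>

int64_t ToSigned(size_t v) { return static_cast<int64_t>(v); }  // stand-in for SizeToLong

int main() {
  size_t pos = 17;
  int dx_width = 5, dx_height = 3;
  int64_t block = (ToSigned(pos) / dx_width) / dx_height;  // 1
  int64_t x = ToSigned(pos) % dx_width;                    // 2
  std::printf("block=%lld x=%lld\n", static_cast<long long>(block), static_cast<long long>(x));
  return 0;
}
```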
@@ -195,33 +194,32 @@ void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, const

 template <typename T>
 void MirrorPadGradCPUKernel::MirrorPadGradBatchChannel(const size_t size, T *dy, T *interim_dy, const int dx_batches,
-                                                       const int dx_channels, const int dx_height, const int dx_width,
-                                                       const int dy_height, const int dy_width, const int padd_dim,
-                                                       const int64_t *paddings_arg, int mode, T *dx) {
+                                                       const int dx_channels, const int dy_height, const int dy_width,
+                                                       const int padd_dim, const int64_t *paddings_arg, int mode) {
   int64_t paddings[MAX_PADDINGS * PADDING_SIZE];  // local and fixed size to keep in registers
   for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) {
     paddings[i] = 0;  // init all to 0
   }
   extract_paddings_(paddings_arg, padd_dim, paddings);
   // Create anchor points for non mirrored data inside new tensor
-  int ap1_channel = paddings[CHANNEL + LEFT];
-  int ap2_channel = paddings[CHANNEL + LEFT] + dx_channels - 1;
-  int ap1_batch = paddings[BATCH + LEFT];
-  int ap2_batch = paddings[BATCH + LEFT] + dx_batches - 1;
-  int dy_channels = dx_channels + paddings[CHANNEL + LEFT] + paddings[CHANNEL + RIGHT];
-  int dy_batches = dx_batches + paddings[BATCH + LEFT] + paddings[BATCH + RIGHT];
+  int ap1_channel = paddings[CHANNEL];
+  int ap2_channel = paddings[CHANNEL] + dx_channels - 1;
+  int ap1_batch = paddings[BATCH];
+  int ap2_batch = paddings[BATCH] + dx_batches - 1;
+  int dy_channels = dx_channels + paddings[CHANNEL] + paddings[CHANNEL + RIGHT];
+  int dy_batches = dx_batches + paddings[BATCH] + paddings[RIGHT];

   for (size_t pos = 0; pos < size; ++pos) {
-    int block_num = (pos / dy_width) / dy_height;
+    int block_num = (SizeToLong(pos) / dy_width) / dy_height;
     // Select exact position inside the dy_interim array
-    const int interim_x = pos % dy_width;
-    const int interim_y = (pos / dy_width) % dy_height;
+    const int interim_x = SizeToLong(pos) % dy_width;
+    const int interim_y = (SizeToLong(pos) / dy_width) % dy_height;
     const int interim_channel = block_num % dx_channels;
     const int interim_batch = block_num / dx_channels;
     interim_dy[pos] = T(0);  // init
     // map cur interim channel and batch to equivalent in padded dy array
-    const int equiv_dy_channel = interim_channel + paddings[CHANNEL + LEFT];
-    const int equiv_dy_batch = interim_batch + paddings[BATCH + LEFT];
+    const int equiv_dy_channel = interim_channel + paddings[CHANNEL];
+    const int equiv_dy_batch = interim_batch + paddings[BATCH];
     int target_batch = 0;
     int target_channel = 0;
     int equiv_block_num = 0;
@@ -258,13 +256,11 @@ void MirrorPadGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   auto interim = reinterpret_cast<T *>(workspace[0]->addr);
   auto outputs_addr = reinterpret_cast<T *>(outputs[0]->addr);

-  MirrorPadGradBatchChannel(workspace_size_, inputs_addr, interim, output_shape_[0], output_shape_[1], output_shape_[2],
-                            output_shape_[3], input_shape_[2], input_shape_[3], num_paddings_, paddings, mode_,
-                            outputs_addr);
+  MirrorPadGradBatchChannel(workspace_size_, inputs_addr, interim, output_shape_[0], output_shape_[1], input_shape_[2],
+                            input_shape_[3], num_paddings_, paddings, mode_);

-  MirrorPadGrad_Width_Height(output_size_, inputs_addr, interim, output_shape_[0], output_shape_[1], output_shape_[2],
-                             output_shape_[3], input_shape_[2], input_shape_[3], num_paddings_, paddings, mode_,
-                             outputs_addr);
+  MirrorPadGrad_Width_Height(output_size_, interim, output_shape_[2], output_shape_[3], input_shape_[2],
+                             input_shape_[3], num_paddings_, paddings, mode_, outputs_addr);
 }

 void MirrorPadGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
@@ -58,14 +58,14 @@ class MirrorPadGradCPUKernel : public CPUKernel {
                     const std::vector<AddressPtr> &outputs);

   template <typename T>
-  void MirrorPadGrad_Width_Height(const size_t size, const T *dy, T *interim_dy, const int dx_batches,
-                                  const int dx_channels, const int dx_height, const int dx_width, const int dy_height,
-                                  const int dy_width, const int padd_dim, const int64_t *paddings_arg, int mode, T *dx);
+  void MirrorPadGrad_Width_Height(const size_t size, T *interim_dy, const int dx_height, const int dx_width,
+                                  const int dy_height, const int dy_width, const int padd_dim,
+                                  const int64_t *paddings_arg, int mode, T *dx);

   template <typename T>
   void MirrorPadGradBatchChannel(const size_t size, T *dy, T *interim_dy, const int dx_batches, const int dx_channels,
-                                 const int dx_height, const int dx_width, const int dy_height, const int dy_width,
-                                 const int padd_dim, const int64_t *paddings_arg, int mode, T *dx);
+                                 const int dy_height, const int dy_width, const int padd_dim,
+                                 const int64_t *paddings_arg, int mode);

  private:
   void CheckParam(const CNodePtr &kernel_node);
@@ -52,9 +52,9 @@ void SigmoidCrossEntropyWithLogitsGradCPUKernel::LaunchKernel(const std::vector<
   T one = (T)1.0;
   for (uint64_t i = 0; i < tensor_size_; ++i) {
     if (logits_addr[i] >= zero) {
-      output_addr[i] = (one / (one + exp(-logits_addr[i])) - labels_addr[i]) * dloss_addr[i];
+      output_addr[i] = (one / (one + static_cast<T>(exp(-logits_addr[i]))) - labels_addr[i]) * dloss_addr[i];
     } else {
-      const T exp_val = exp(logits_addr[i]);
+      const T exp_val = static_cast<T>(exp(logits_addr[i]));
       output_addr[i] = (exp_val / (one + exp_val) - labels_addr[i]) * dloss_addr[i];
     }
   }
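Background (not part of the patch): beyond the explicit casts, the branch on the sign of the logit is the standard numerically stable sigmoid, using 1/(1+exp(-x)) for x >= 0 and exp(x)/(1+exp(x)) otherwise, so exp never receives a large positive argument; the gradient is then (sigmoid(x) - label) * dloss. A stand-alone sketch with illustrative names:

```cpp
// Numerically stable sigmoid and the corresponding BCE-with-logits gradient,
// mirroring the branch structure of the kernel above.
#include <cmath>
#include <cstdio>

template <typename T>
T StableSigmoid(T x) {
  if (x >= T(0)) {
    return T(1) / (T(1) + static_cast<T>(std::exp(-x)));
  }
  const T exp_val = static_cast<T>(std::exp(x));  // exp of a negative value: no overflow
  return exp_val / (T(1) + exp_val);
}

int main() {
  float logit = -100.0f, label = 1.0f, dloss = 1.0f;
  float grad = (StableSigmoid(logit) - label) * dloss;
  std::printf("sigmoid=%g grad=%g\n", StableSigmoid(logit), grad);  // sigmoid ~ 0, grad ~ -1
  return 0;
}
```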