merge
This commit is contained in:
parent
ab7e20c870
commit
56e8d4ef9d
|
@ -64,7 +64,7 @@ Status BiasAddInfo::InferTensorMap() {
|
|||
for (size_t i = 0; i < sub_a_strategy_size; ++i) {
|
||||
sub_a_tensor_map.push_back(static_cast<int64_t>(LAST_INDEX(sub_a_strategy_size) - i));
|
||||
}
|
||||
sub_b_tensor_map.push_back(static_cast<int64_t>(LAST_INDEX(sub_a_strategy_size) - static_cast<int64_t>(1)));
|
||||
sub_b_tensor_map.push_back(static_cast<int64_t>(LAST_INDEX(sub_a_strategy_size)) - static_cast<int64_t>(1));
|
||||
|
||||
inputs_tensor_map_.push_back(sub_a_tensor_map);
|
||||
inputs_tensor_map_.push_back(sub_b_tensor_map);
|
||||
|
|
|
@ -93,7 +93,7 @@ void NoRepeatNGramCpuKernelMod::CheckAndInitParams() {
|
|||
|
||||
template <typename T>
|
||||
bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);
|
||||
|
@ -113,14 +113,15 @@ bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
|
|||
int64_t output_index_i = output_dim_ * i;
|
||||
for (int64_t k = 0; k < state_dim_; k++) {
|
||||
int64_t src_index_k = k + src_index_i;
|
||||
array_dim[k] = static_cast<int32_t>(state_seq[src_index_k]);
|
||||
array_dim[LongToSize(k)] = static_cast<int32_t>(state_seq[LongToSize(src_index_k)]);
|
||||
if (k > (state_dim_ - ngram_size_)) {
|
||||
array_ngram[k + ngram_size_ - state_dim_ - 1] = static_cast<int32_t>(state_seq[src_index_k]);
|
||||
array_ngram[LongToSize(k + ngram_size_ - state_dim_ - 1)] =
|
||||
static_cast<int32_t>(state_seq[LongToSize(src_index_k)]);
|
||||
}
|
||||
}
|
||||
for (int64_t j = 0; j < state_dim_ - ngram_size_ + 1; j++) {
|
||||
if (equal(array_ngram.begin(), array_ngram.end(), array_dim.begin() + j)) {
|
||||
int64_t output_index_j = static_cast<int64_t>(array_dim[j + ngram_size_ - 1]);
|
||||
int64_t output_index_j = static_cast<int64_t>(array_dim[LongToSize(j + ngram_size_ - 1)]);
|
||||
output[output_index_i + output_index_j] = -(std::numeric_limits<T>::max)();
|
||||
}
|
||||
}
|
||||
|
@ -137,7 +138,7 @@ std::vector<std::pair<KernelAttr, NoRepeatNGramCpuKernelMod::NoRepeatNGramFunc>>
|
|||
&NoRepeatNGramCpuKernelMod::LaunchKernel<float16>}};
|
||||
|
||||
std::vector<KernelAttr> NoRepeatNGramCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list;
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, NoRepeatNGramFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
|
|
Loading…
Reference in New Issue