This commit is contained in:
bichaoyang 2022-08-08 14:42:38 +08:00
parent ab7e20c870
commit 56e8d4ef9d
2 changed files with 7 additions and 6 deletions

View File

@ -64,7 +64,7 @@ Status BiasAddInfo::InferTensorMap() {
for (size_t i = 0; i < sub_a_strategy_size; ++i) {
sub_a_tensor_map.push_back(static_cast<int64_t>(LAST_INDEX(sub_a_strategy_size) - i));
}
sub_b_tensor_map.push_back(static_cast<int64_t>(LAST_INDEX(sub_a_strategy_size) - static_cast<int64_t>(1)));
sub_b_tensor_map.push_back(static_cast<int64_t>(LAST_INDEX(sub_a_strategy_size)) - static_cast<int64_t>(1));
inputs_tensor_map_.push_back(sub_a_tensor_map);
inputs_tensor_map_.push_back(sub_b_tensor_map);

View File

@ -93,7 +93,7 @@ void NoRepeatNGramCpuKernelMod::CheckAndInitParams() {
template <typename T>
bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);
@ -113,14 +113,15 @@ bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
int64_t output_index_i = output_dim_ * i;
for (int64_t k = 0; k < state_dim_; k++) {
int64_t src_index_k = k + src_index_i;
array_dim[k] = static_cast<int32_t>(state_seq[src_index_k]);
array_dim[LongToSize(k)] = static_cast<int32_t>(state_seq[LongToSize(src_index_k)]);
if (k > (state_dim_ - ngram_size_)) {
array_ngram[k + ngram_size_ - state_dim_ - 1] = static_cast<int32_t>(state_seq[src_index_k]);
array_ngram[LongToSize(k + ngram_size_ - state_dim_ - 1)] =
static_cast<int32_t>(state_seq[LongToSize(src_index_k)]);
}
}
for (int64_t j = 0; j < state_dim_ - ngram_size_ + 1; j++) {
if (equal(array_ngram.begin(), array_ngram.end(), array_dim.begin() + j)) {
int64_t output_index_j = static_cast<int64_t>(array_dim[j + ngram_size_ - 1]);
int64_t output_index_j = static_cast<int64_t>(array_dim[LongToSize(j + ngram_size_ - 1)]);
output[output_index_i + output_index_j] = -(std::numeric_limits<T>::max)();
}
}
@ -137,7 +138,7 @@ std::vector<std::pair<KernelAttr, NoRepeatNGramCpuKernelMod::NoRepeatNGramFunc>>
&NoRepeatNGramCpuKernelMod::LaunchKernel<float16>}};
std::vector<KernelAttr> NoRepeatNGramCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, NoRepeatNGramFunc> &pair) { return pair.first; });
return support_list;