[assistant] [I4XJI5]fix ctclossv2 cpu bug & doc

This commit is contained in:
lh735291378 2022-08-31 17:49:06 +08:00
parent 52ee32f20d
commit 166c955fea
6 changed files with 14 additions and 2 deletions

View File

@ -36,6 +36,7 @@ bool CTCLossV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
}
auto kernel_ptr = std::make_shared<ops::CTCLossV2>(base_operator->GetPrim());
blank_ = kernel_ptr->get_blank();
zero_infinity_ = kernel_ptr->get_zero_infinity();
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
@ -190,6 +191,10 @@ bool CTCLossV2CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
}
};
ParallelLaunchAutoSearch(task, LongToSize(batch_sizes_), this, &parallel_search_info_);
if (zero_infinity_) {
constexpr S zero = static_cast<S>(0);
std::replace(neg_log_p, neg_log_p + batch_sizes_, std::numeric_limits<S>::infinity(), zero);
}
return true;
}
const std::vector<std::pair<KernelAttr, CTCLossV2CpuKernelMod::KernelRunFunc>> &CTCLossV2CpuKernelMod::GetFuncList()

View File

@ -76,6 +76,7 @@ class CTCLossV2CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
// Stands for C
int64_t num_labels_{0};
// Stands for S
bool zero_infinity_{false};
int64_t max_target_length_{0};
template <typename T, typename S>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,

View File

@ -27,6 +27,7 @@ namespace mindspore {
namespace ops {
int64_t CTCLossV2::get_blank() const { return GetValue<int64_t>(GetAttr(kAttrBlank)); }
std::string CTCLossV2::get_reduction() const { return GetValue<std::string>(GetAttr(kAttrReduction)); }
bool CTCLossV2::get_zero_infinity() const { return GetValue<bool>(GetAttr(kAttrZeroInfinity)); }
namespace {
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {

View File

@ -48,6 +48,11 @@ class MIND_API CTCLossV2 : public BaseOperator {
///
/// \return reduction.
std::string get_reduction() const;
/// \brief Get zero_infinity.
///
/// \return zero_infinity.
bool get_zero_infinity() const;
};
abstract::AbstractBasePtr CTCLossV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -2048,7 +2048,7 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
RuntimeError: If the shape of `target_lengths` does not match {batch_size|N}.
RuntimeError: If the types of `targets`, `input_lengths`, `grad_out` or `target_lengths` are different.
RuntimeError: If the value of `blank` is not in range [0, num_labels|C).
RuntimeError: If any value of `input_lengths` is larger than (num_labels|C).
RuntimeError: If any value of `input_lengths` is larger than (time_series|T).
RuntimeError: If any target_lengths[i] is not in range [0, input_length[i]].
Supported Platforms:

View File

@ -8383,7 +8383,7 @@ class CTCLossV2(Primitive):
RuntimeError: If the shape of `target_lengths` does not match {batch_size|N}.
RuntimeError: If the types of `targets`, `input_lengths` or `target_lengths` are different.
RuntimeError: If the value of `blank` is not in range [0, num_labels|C).
RuntimeError: If any value of `input_lengths` is larger than (num_labels|C).
RuntimeError: If any value of `input_lengths` is larger than (time_series|T).
RuntimeError: If any target_lengths[i] is not in range [0, input_length[i]].
Supported Platforms: