forked from mindspore-Ecosystem/mindspore
[assistant] [I4XJI5]fix ctclossv2 cpu bug & doc
This commit is contained in:
parent
52ee32f20d
commit
166c955fea
|
@ -36,6 +36,7 @@ bool CTCLossV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
|
|||
}
|
||||
auto kernel_ptr = std::make_shared<ops::CTCLossV2>(base_operator->GetPrim());
|
||||
blank_ = kernel_ptr->get_blank();
|
||||
zero_infinity_ = kernel_ptr->get_zero_infinity();
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -190,6 +191,10 @@ bool CTCLossV2CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
|
|||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, LongToSize(batch_sizes_), this, ¶llel_search_info_);
|
||||
if (zero_infinity_) {
|
||||
constexpr S zero = static_cast<S>(0);
|
||||
std::replace(neg_log_p, neg_log_p + batch_sizes_, std::numeric_limits<S>::infinity(), zero);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
const std::vector<std::pair<KernelAttr, CTCLossV2CpuKernelMod::KernelRunFunc>> &CTCLossV2CpuKernelMod::GetFuncList()
|
||||
|
|
|
@ -76,6 +76,7 @@ class CTCLossV2CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
|
|||
// Stands for C
|
||||
int64_t num_labels_{0};
|
||||
// Stands for S
|
||||
bool zero_infinity_{false};
|
||||
int64_t max_target_length_{0};
|
||||
template <typename T, typename S>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
|
|
|
@ -27,6 +27,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
int64_t CTCLossV2::get_blank() const { return GetValue<int64_t>(GetAttr(kAttrBlank)); }
|
||||
std::string CTCLossV2::get_reduction() const { return GetValue<std::string>(GetAttr(kAttrReduction)); }
|
||||
bool CTCLossV2::get_zero_infinity() const { return GetValue<bool>(GetAttr(kAttrZeroInfinity)); }
|
||||
namespace {
|
||||
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -48,6 +48,11 @@ class MIND_API CTCLossV2 : public BaseOperator {
|
|||
///
|
||||
/// \return reduction.
|
||||
std::string get_reduction() const;
|
||||
|
||||
/// \brief Get zero_infinity.
|
||||
///
|
||||
/// \return zero_infinity.
|
||||
bool get_zero_infinity() const;
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr CTCLossV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -2048,7 +2048,7 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
|
|||
RuntimeError: If the shape of `target_lengths` does not match {batch_size|N}.
|
||||
RuntimeError: If the types of `targets`, `input_lengths`, `grad_out` or `target_lengths` are different.
|
||||
RuntimeError: If the value of `blank` is not in range [0, num_labels|C).
|
||||
RuntimeError: If any value of `input_lengths` is larger than (num_labels|C).
|
||||
RuntimeError: If any value of `input_lengths` is larger than (time_series|T).
|
||||
RuntimeError: If any target_lengths[i] is not in range [0, input_length[i]].
|
||||
|
||||
Supported Platforms:
|
||||
|
|
|
@ -8383,7 +8383,7 @@ class CTCLossV2(Primitive):
|
|||
RuntimeError: If the shape of `target_lengths` does not match {batch_size|N}.
|
||||
RuntimeError: If the types of `targets`, `input_lengths` or `target_lengths` are different.
|
||||
RuntimeError: If the value of `blank` is not in range [0, num_labels|C).
|
||||
RuntimeError: If any value of `input_lengths` is larger than (num_labels|C).
|
||||
RuntimeError: If any value of `input_lengths` is larger than (time_series|T).
|
||||
RuntimeError: If any target_lengths[i] is not in range [0, input_length[i]].
|
||||
|
||||
Supported Platforms:
|
||||
|
|
Loading…
Reference in New Issue