add the no_repeat_ngram check

This commit is contained in:
bichaoyang 2022-11-04 14:35:19 +08:00
parent 0d0e70a433
commit 7a433f811d
1 changed files with 23 additions and 4 deletions

View File

@ -31,15 +31,34 @@ namespace ops {
namespace {
abstract::ShapePtr NoRepeatNGramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto in_shape = shape_map[kShape];
if (IsDynamicRank(in_shape)) {
auto ngram_size = GetValue<int64_t>(primitive->GetAttr(kNgramSize));
const int64_t kShapeSize = 3;
constexpr int64_t kIndex0 = 0;
constexpr int64_t kIndex1 = 1;
constexpr int64_t kIndex2 = 2;
auto state_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto state_shape = state_shape_map[kShape];
auto log_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto log_shape = log_shape_map[kShape];
if (IsDynamicRank(log_shape)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
return std::make_shared<abstract::Shape>(in_shape);
(void)CheckAndConvertUtils::CheckInteger("rank of state_seq", SizeToLong(state_shape.size()), kEqual, kShapeSize,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of log_probs", SizeToLong(log_shape.size()), kEqual, kShapeSize,
prim_name);
(void)CheckAndConvertUtils::CheckValue("state_seq shape[0]", state_shape.at(kIndex0), kEqual, "log_probs shape[0]",
log_shape.at(kIndex0), prim_name);
(void)CheckAndConvertUtils::CheckValue("state_seq shape[1]", state_shape.at(kIndex1), kEqual, "log_probs shape[1]",
log_shape.at(kIndex1), prim_name);
(void)CheckAndConvertUtils::CheckValue("ngram_size", ngram_size, kLessEqual, "state_seq shape[2] + 1",
state_shape.at(kIndex2), prim_name);
return std::make_shared<abstract::Shape>(log_shape);
}
TypePtr NoRepeatNGramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);