add the no_repeat_ngram check
This commit is contained in:
parent
0d0e70a433
commit
7a433f811d
|
@ -31,15 +31,34 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr NoRepeatNGramInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
if (IsDynamicRank(in_shape)) {
|
||||
auto ngram_size = GetValue<int64_t>(primitive->GetAttr(kNgramSize));
|
||||
const int64_t kShapeSize = 3;
|
||||
constexpr int64_t kIndex0 = 0;
|
||||
constexpr int64_t kIndex1 = 1;
|
||||
constexpr int64_t kIndex2 = 2;
|
||||
auto state_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto state_shape = state_shape_map[kShape];
|
||||
auto log_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto log_shape = log_shape_map[kShape];
|
||||
if (IsDynamicRank(log_shape)) {
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of state_seq", SizeToLong(state_shape.size()), kEqual, kShapeSize,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of log_probs", SizeToLong(log_shape.size()), kEqual, kShapeSize,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckValue("state_seq shape[0]", state_shape.at(kIndex0), kEqual, "log_probs shape[0]",
|
||||
log_shape.at(kIndex0), prim_name);
|
||||
(void)CheckAndConvertUtils::CheckValue("state_seq shape[1]", state_shape.at(kIndex1), kEqual, "log_probs shape[1]",
|
||||
log_shape.at(kIndex1), prim_name);
|
||||
(void)CheckAndConvertUtils::CheckValue("ngram_size", ngram_size, kLessEqual, "state_seq shape[2] + 1",
|
||||
state_shape.at(kIndex2), prim_name);
|
||||
return std::make_shared<abstract::Shape>(log_shape);
|
||||
}
|
||||
TypePtr NoRepeatNGramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
|
Loading…
Reference in New Issue