fix bug of USC infer

This commit is contained in:
mengyuanli 2022-11-24 14:25:11 +08:00
parent acc1c3f127
commit 59798e9e3a
2 changed files with 5 additions and 2 deletions

View File

@ -195,7 +195,6 @@ bool UniformCandidateSamplerCpuKernelMod::Init(const BaseOperatorPtr &base_opera
init_seed_ = LongToUint(seed_);
// check the attribute, inputs and outputs
CheckAttribute();
CheckInputsAndOutputs(inputs, outputs);
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
@ -211,6 +210,7 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
if (ret != 0) {
return ret;
}
CheckInputsAndOutputs(inputs, outputs);
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());

View File

@ -65,7 +65,10 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
int64_t range_max = GetValue<int64_t>(primitive->GetAttr("range_max"));
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name);
(void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name);
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], op_name);
if (!IsDynamic(input_shape)) {
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1],
op_name);
}
if (unique) {
(void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name);
}