forked from mindspore-Ecosystem/mindspore
fix bug of USC infer
This commit is contained in:
parent
acc1c3f127
commit
59798e9e3a
|
@ -195,7 +195,6 @@ bool UniformCandidateSamplerCpuKernelMod::Init(const BaseOperatorPtr &base_opera
|
|||
init_seed_ = LongToUint(seed_);
|
||||
// check the attribute, inputs and outputs
|
||||
CheckAttribute();
|
||||
CheckInputsAndOutputs(inputs, outputs);
|
||||
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
|
@ -211,6 +210,7 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
|
|||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
CheckInputsAndOutputs(inputs, outputs);
|
||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
|
||||
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
|
|
|
@ -65,7 +65,10 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
|
|||
int64_t range_max = GetValue<int64_t>(primitive->GetAttr("range_max"));
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], op_name);
|
||||
if (!IsDynamic(input_shape)) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1],
|
||||
op_name);
|
||||
}
|
||||
if (unique) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue