forked from mindspore-Ecosystem/mindspore
fix bug of USC infer
This commit is contained in:
parent
acc1c3f127
commit
59798e9e3a
|
@ -195,7 +195,6 @@ bool UniformCandidateSamplerCpuKernelMod::Init(const BaseOperatorPtr &base_opera
|
||||||
init_seed_ = LongToUint(seed_);
|
init_seed_ = LongToUint(seed_);
|
||||||
// check the attribute, inputs and outputs
|
// check the attribute, inputs and outputs
|
||||||
CheckAttribute();
|
CheckAttribute();
|
||||||
CheckInputsAndOutputs(inputs, outputs);
|
|
||||||
|
|
||||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -211,6 +210,7 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
CheckInputsAndOutputs(inputs, outputs);
|
||||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||||
|
|
||||||
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
||||||
|
|
|
@ -65,7 +65,10 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
|
||||||
int64_t range_max = GetValue<int64_t>(primitive->GetAttr("range_max"));
|
int64_t range_max = GetValue<int64_t>(primitive->GetAttr("range_max"));
|
||||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name);
|
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name);
|
||||||
(void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name);
|
(void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name);
|
||||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], op_name);
|
if (!IsDynamic(input_shape)) {
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1],
|
||||||
|
op_name);
|
||||||
|
}
|
||||||
if (unique) {
|
if (unique) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name);
|
(void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue