From 59798e9e3aafddea32b14b8d477a5715b4009af1 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Thu, 24 Nov 2022 14:25:11 +0800 Subject: [PATCH] fix bug of USC infer --- .../cpu/kernel/uniform_candidate_sampler_cpu_kernel.cc | 2 +- mindspore/core/ops/uniform_candidate_sampler.cc | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_candidate_sampler_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_candidate_sampler_cpu_kernel.cc index ac39a098a99..ed612e413e6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_candidate_sampler_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_candidate_sampler_cpu_kernel.cc @@ -195,7 +195,6 @@ bool UniformCandidateSamplerCpuKernelMod::Init(const BaseOperatorPtr &base_opera init_seed_ = LongToUint(seed_); // check the attribute, inputs and outputs CheckAttribute(); - CheckInputsAndOutputs(inputs, outputs); if (!MatchKernelFunc(base_operator, inputs, outputs)) { return false; @@ -211,6 +210,7 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper if (ret != 0) { return ret; } + CheckInputsAndOutputs(inputs, outputs); auto output_shape = outputs.at(kIndex0)->GetShapeVector(); batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); diff --git a/mindspore/core/ops/uniform_candidate_sampler.cc b/mindspore/core/ops/uniform_candidate_sampler.cc index 12ac7cd1a7f..ba3271ba2ae 100644 --- a/mindspore/core/ops/uniform_candidate_sampler.cc +++ b/mindspore/core/ops/uniform_candidate_sampler.cc @@ -65,7 +65,10 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std:: int64_t range_max = GetValue(primitive->GetAttr("range_max")); (void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name); (void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name); - (void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], op_name); + if (!IsDynamic(input_shape)) { + (void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], + op_name); + } if (unique) { (void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name); }