!40083 [MS][OPS] UniformCandidateSampler add parameter's checking in C++ primitive

Merge pull request !40083 from louie5/master
This commit is contained in:
i-robot 2022-08-10 03:31:09 +00:00 committed by Gitee
commit 4902e3f58e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 20 additions and 6 deletions

View File

@ -164,13 +164,13 @@ void UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vecto
}
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
batch_size_ = batch_size_ / num_sampled_;
if (batch_size_ == 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of output 'sampled_candidates' can not be 0";
}
input_size_ = LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>()));
input_size_ =
LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), int64_t(1), std::multiplies<int64_t>()));
input_size_ = input_size_ / LongToSize(batch_size_);
(void)output_sizes_.emplace_back(num_sampled_);
@ -228,16 +228,18 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
}
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
batch_size_ = batch_size_ / num_sampled_;
if (batch_size_ == 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of output 'sampled_candidates' can not be 0";
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
input_size_ = LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>()));
input_size_ =
LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), int64_t(1), std::multiplies<int64_t>()));
input_size_ = input_size_ / LongToSize(batch_size_);
output_sizes_.clear();
(void)output_sizes_.emplace_back(num_sampled_);
(void)output_sizes_.emplace_back(input_size_);
(void)output_sizes_.emplace_back(num_sampled_);

View File

@ -54,6 +54,18 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
op_name);
}
int64_t num_true = GetValue<int64_t>(primitive->GetAttr("num_true"));
int64_t seed = GetValue<int64_t>(primitive->GetAttr("seed"));
bool unique = GetValue<bool>(primitive->GetAttr("unique"));
int64_t num_sampled = GetValue<int64_t>(primitive->GetAttr("num_sampled"));
int64_t range_max = GetValue<int64_t>(primitive->GetAttr("range_max"));
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name);
(void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name);
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], op_name);
if (unique) {
(void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name);
}
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });
auto true_expected_count_shape = input_shape_ptr;
@ -61,7 +73,6 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
true_expected_count_shape = std::make_shared<abstract::Shape>(input_shape, min_shape, max_shape);
}
auto num_sampled = GetValue<int64_t>(primitive->GetAttr("num_sampled"));
std::vector<int64_t> batch_lists;
for (int64_t i = 0; i < batch_rank; i++) {
(void)batch_lists.emplace_back(input_shape[i]);

View File

@ -838,6 +838,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
Validator.check_value_type("range_max", range_max, [int], self.name)
Validator.check_value_type("seed", seed, [int], self.name)
Validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
Validator.check("value of num_true", num_true, '', 0, Rel.GT, self.name)
Validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
Validator.check("value of range_max", range_max, '', 0, Rel.GT, self.name)
self.num_true = num_true