!40083 [MS][OPS] UniformCandidateSampler add parameter's checking in C++ primitive
Merge pull request !40083 from louie5/master
This commit is contained in:
commit
4902e3f58e
|
@ -164,13 +164,13 @@ void UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vecto
|
|||
}
|
||||
|
||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
|
||||
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
|
||||
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
batch_size_ = batch_size_ / num_sampled_;
|
||||
if (batch_size_ == 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of output 'sampled_candidates' can not be 0";
|
||||
}
|
||||
input_size_ = LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>()));
|
||||
input_size_ =
|
||||
LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), int64_t(1), std::multiplies<int64_t>()));
|
||||
input_size_ = input_size_ / LongToSize(batch_size_);
|
||||
|
||||
(void)output_sizes_.emplace_back(num_sampled_);
|
||||
|
@ -228,16 +228,18 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
|
|||
}
|
||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
|
||||
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
|
||||
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
batch_size_ = batch_size_ / num_sampled_;
|
||||
if (batch_size_ == 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of output 'sampled_candidates' can not be 0";
|
||||
}
|
||||
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
input_size_ = LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>()));
|
||||
input_size_ =
|
||||
LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), int64_t(1), std::multiplies<int64_t>()));
|
||||
input_size_ = input_size_ / LongToSize(batch_size_);
|
||||
|
||||
output_sizes_.clear();
|
||||
(void)output_sizes_.emplace_back(num_sampled_);
|
||||
(void)output_sizes_.emplace_back(input_size_);
|
||||
(void)output_sizes_.emplace_back(num_sampled_);
|
||||
|
|
|
@ -54,6 +54,18 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
|
|||
op_name);
|
||||
}
|
||||
|
||||
int64_t num_true = GetValue<int64_t>(primitive->GetAttr("num_true"));
|
||||
int64_t seed = GetValue<int64_t>(primitive->GetAttr("seed"));
|
||||
bool unique = GetValue<bool>(primitive->GetAttr("unique"));
|
||||
int64_t num_sampled = GetValue<int64_t>(primitive->GetAttr("num_sampled"));
|
||||
int64_t range_max = GetValue<int64_t>(primitive->GetAttr("range_max"));
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kGreaterThan, 0, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("seed", seed, kGreaterEqual, 0, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_true", num_true, kEqual, input_shape[input_shape.size() - 1], op_name);
|
||||
if (unique) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_sampled", num_sampled, kLessEqual, range_max, op_name);
|
||||
}
|
||||
|
||||
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
|
||||
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });
|
||||
auto true_expected_count_shape = input_shape_ptr;
|
||||
|
@ -61,7 +73,6 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
|
|||
true_expected_count_shape = std::make_shared<abstract::Shape>(input_shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
auto num_sampled = GetValue<int64_t>(primitive->GetAttr("num_sampled"));
|
||||
std::vector<int64_t> batch_lists;
|
||||
for (int64_t i = 0; i < batch_rank; i++) {
|
||||
(void)batch_lists.emplace_back(input_shape[i]);
|
||||
|
|
|
@ -838,6 +838,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
|
|||
Validator.check_value_type("range_max", range_max, [int], self.name)
|
||||
Validator.check_value_type("seed", seed, [int], self.name)
|
||||
Validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
|
||||
Validator.check("value of num_true", num_true, '', 0, Rel.GT, self.name)
|
||||
Validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
|
||||
Validator.check("value of range_max", range_max, '', 0, Rel.GT, self.name)
|
||||
self.num_true = num_true
|
||||
|
|
Loading…
Reference in New Issue