!40535 fix gpu op histogram_fixed_width bug

Merge pull request !40535 from jin_jiaqi/histo
This commit is contained in:
i-robot 2022-08-26 01:21:55 +00:00 committed by Gitee
commit 4a1f1aff5b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 4 additions and 4 deletions

View File

@ -51,12 +51,12 @@ void HistogramFixedWidth::set_nbins(const int32_t nbins) {
(void)this->AddAttr(kNbins, api::MakeValue(nbins));
}
void HistogramFixedWidth::set_dtype(const TypeId dtype) { (void)this->AddAttr(kDType, api::Type::GetType(dtype)); }
void HistogramFixedWidth::set_dtype(const TypeId dtype) { (void)this->AddAttr("dtype", api::Type::GetType(dtype)); }
int32_t HistogramFixedWidth::get_nbins() const { return static_cast<int32_t>(GetValue<int64_t>(GetAttr(kNbins))); }
TypeId HistogramFixedWidth::get_dtype() const {
return GetAttr(kDType)->cast<api::TensorTypePtr>()->element()->type_id();
return GetAttr("dtype")->cast<api::TensorTypePtr>()->element()->type_id();
}
void HistogramFixedWidth::Init(const int32_t nbins, const TypeId dtype) {

View File

@ -2600,9 +2600,9 @@ class HistogramFixedWidth(PrimitiveWithInfer):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
valid_values = ['int32']
self.dtype = validator.check_string(dtype, valid_values, "d_type", self.name)
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
self.add_prim_attr('d_type', 3)
self.add_prim_attr('dtype', 3)