forked from mindspore-Ecosystem/mindspore
!40535 fix gpu op histogram_fixed_width bug
Merge pull request !40535 from jin_jiaqi/histo
This commit is contained in:
commit
4a1f1aff5b
|
@ -51,12 +51,12 @@ void HistogramFixedWidth::set_nbins(const int32_t nbins) {
|
|||
(void)this->AddAttr(kNbins, api::MakeValue(nbins));
|
||||
}
|
||||
|
||||
void HistogramFixedWidth::set_dtype(const TypeId dtype) { (void)this->AddAttr(kDType, api::Type::GetType(dtype)); }
|
||||
void HistogramFixedWidth::set_dtype(const TypeId dtype) { (void)this->AddAttr("dtype", api::Type::GetType(dtype)); }
|
||||
|
||||
int32_t HistogramFixedWidth::get_nbins() const { return static_cast<int32_t>(GetValue<int64_t>(GetAttr(kNbins))); }
|
||||
|
||||
TypeId HistogramFixedWidth::get_dtype() const {
|
||||
return GetAttr(kDType)->cast<api::TensorTypePtr>()->element()->type_id();
|
||||
return GetAttr("dtype")->cast<api::TensorTypePtr>()->element()->type_id();
|
||||
}
|
||||
|
||||
void HistogramFixedWidth::Init(const int32_t nbins, const TypeId dtype) {
|
||||
|
|
|
@ -2600,9 +2600,9 @@ class HistogramFixedWidth(PrimitiveWithInfer):
|
|||
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
|
||||
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
|
||||
valid_values = ['int32']
|
||||
self.dtype = validator.check_string(dtype, valid_values, "d_type", self.name)
|
||||
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
|
||||
self.add_prim_attr('d_type', 3)
|
||||
self.add_prim_attr('dtype', 3)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue