forked from mindspore-Ecosystem/mindspore
[MSLITE][Develop] fix bug of fp32 grad op: unsorted_segmaent_sum
This commit is contained in:
parent
4e9002329f
commit
a0e0bb5f68
|
@ -38,6 +38,8 @@ int UnsortedSegmentSumCPUKernel::Init() {
|
|||
auto input_shape = in_tensors_.at(0)->shape();
|
||||
auto segment_ids_shape = in_tensors_.at(1)->shape();
|
||||
auto output_shape = out_tensors_.at(0)->shape();
|
||||
unit_num_ = 1;
|
||||
input_dim1_ = 1;
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
unit_num_ *= input_shape[i];
|
||||
if (i >= segment_ids_shape.size()) {
|
||||
|
@ -45,6 +47,7 @@ int UnsortedSegmentSumCPUKernel::Init() {
|
|||
}
|
||||
}
|
||||
output_dim0_ = output_shape[0];
|
||||
output_dim1_ = 1;
|
||||
for (size_t j = 1; j < output_shape.size(); j++) {
|
||||
output_dim1_ *= output_shape[j];
|
||||
}
|
||||
|
|
|
@ -32,10 +32,10 @@ class UnsortedSegmentSumCPUKernel : public LiteKernel {
|
|||
int ReSize() override;
|
||||
int Run() override;
|
||||
int Execute(int task_id);
|
||||
size_t unit_num_;
|
||||
size_t input_dim1_;
|
||||
size_t output_dim0_;
|
||||
size_t output_dim1_;
|
||||
size_t unit_num_ = 0;
|
||||
size_t input_dim1_ = 0;
|
||||
size_t output_dim0_ = 0;
|
||||
size_t output_dim1_ = 0;
|
||||
|
||||
private:
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue