fix maskedfill resize

This commit is contained in:
fan-jibin 2022-12-26 19:23:27 +08:00
parent 9c2e9e9c85
commit 95cf7e641b
2 changed files with 11 additions and 17 deletions

View File

@ -51,14 +51,11 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret;
}
std::vector<int64_t> input_shape = inputs.at(kIndex0)->GetShapeVector();
std::vector<int64_t> mask_shape = inputs.at(kIndex1)->GetShapeVector();
std::vector<int64_t> value_shape = inputs.at(kIndex2)->GetShapeVector();
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
(void)std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(mask_shape_), LongToSize);
(void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
need_broadcast_ = (input_shape_ == mask_shape_) ? false : true;
ShapeVector input_shape = inputs.at(kIndex0)->GetShapeVector();
ShapeVector mask_shape = inputs.at(kIndex1)->GetShapeVector();
ShapeVector value_shape = inputs.at(kIndex2)->GetShapeVector();
ShapeVector output_shape = outputs.at(kIndex0)->GetShapeVector();
need_broadcast_ = (input_shape == mask_shape) ? false : true;
size_t batch_size = value_shape.size();
if (LongToSize(batch_rank_) != batch_size) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value shape size should equal to " << batch_rank_
@ -78,18 +75,18 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
}
}
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
output_size_ = std::accumulate(output_shape.begin(), output_shape.end(), size_t(1), std::multiplies<size_t>());
value_size_ =
LongToSize(std::accumulate(value_shape.begin(), value_shape.end(), int64_t(1), std::multiplies<int64_t>()));
MS_EXCEPTION_IF_ZERO("value_size", value_size_);
inner_size_ = output_size_ / value_size_;
MS_EXCEPTION_IF_ZERO("inner_size", inner_size_);
mask_index_.clear();
input_index_.clear();
mask_index_.resize(output_size_);
input_index_.resize(output_size_);
if (need_broadcast_) {
BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
mask_index_.clear();
input_index_.clear();
mask_index_.resize(output_size_);
input_index_.resize(output_size_);
BroadcastIterator base_iter(input_shape, mask_shape, output_shape);
base_iter.SetPos(0);
for (size_t i = 0; i < output_size_; i++) {
mask_index_[i] = base_iter.GetInputPosB();

View File

@ -53,9 +53,6 @@ class MaskedFillCpuKernelMod : public NativeCpuKernelMod {
size_t inner_size_{1};
size_t value_size_{1};
int64_t batch_rank_{0};
std::vector<int64_t> input_shape_;
std::vector<int64_t> mask_shape_;
std::vector<int64_t> output_shape_;
std::vector<size_t> mask_index_;
std::vector<size_t> input_index_;
bool need_broadcast_{false};