fix maskedfill resize
This commit is contained in:
parent
9c2e9e9c85
commit
95cf7e641b
|
@ -51,14 +51,11 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
|
|||
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) {
|
||||
return ret;
|
||||
}
|
||||
std::vector<int64_t> input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
std::vector<int64_t> mask_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
std::vector<int64_t> value_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
|
||||
(void)std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(mask_shape_), LongToSize);
|
||||
(void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
|
||||
need_broadcast_ = (input_shape_ == mask_shape_) ? false : true;
|
||||
ShapeVector input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
ShapeVector mask_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
ShapeVector value_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
ShapeVector output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
need_broadcast_ = (input_shape == mask_shape) ? false : true;
|
||||
size_t batch_size = value_shape.size();
|
||||
if (LongToSize(batch_rank_) != batch_size) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value shape size should equal to " << batch_rank_
|
||||
|
@ -78,18 +75,18 @@ int MaskedFillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
|
|||
}
|
||||
}
|
||||
|
||||
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
|
||||
output_size_ = std::accumulate(output_shape.begin(), output_shape.end(), size_t(1), std::multiplies<size_t>());
|
||||
value_size_ =
|
||||
LongToSize(std::accumulate(value_shape.begin(), value_shape.end(), int64_t(1), std::multiplies<int64_t>()));
|
||||
MS_EXCEPTION_IF_ZERO("value_size", value_size_);
|
||||
inner_size_ = output_size_ / value_size_;
|
||||
MS_EXCEPTION_IF_ZERO("inner_size", inner_size_);
|
||||
mask_index_.clear();
|
||||
input_index_.clear();
|
||||
mask_index_.resize(output_size_);
|
||||
input_index_.resize(output_size_);
|
||||
if (need_broadcast_) {
|
||||
BroadcastIterator base_iter(input_shape_, mask_shape_, output_shape_);
|
||||
mask_index_.clear();
|
||||
input_index_.clear();
|
||||
mask_index_.resize(output_size_);
|
||||
input_index_.resize(output_size_);
|
||||
BroadcastIterator base_iter(input_shape, mask_shape, output_shape);
|
||||
base_iter.SetPos(0);
|
||||
for (size_t i = 0; i < output_size_; i++) {
|
||||
mask_index_[i] = base_iter.GetInputPosB();
|
||||
|
|
|
@ -53,9 +53,6 @@ class MaskedFillCpuKernelMod : public NativeCpuKernelMod {
|
|||
size_t inner_size_{1};
|
||||
size_t value_size_{1};
|
||||
int64_t batch_rank_{0};
|
||||
std::vector<int64_t> input_shape_;
|
||||
std::vector<int64_t> mask_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
std::vector<size_t> mask_index_;
|
||||
std::vector<size_t> input_index_;
|
||||
bool need_broadcast_{false};
|
||||
|
|
Loading…
Reference in New Issue