forked from mindspore-Ecosystem/mindspore
!42840 Fix bias_add problem
Merge pull request !42840 from LiangZhibo/bias
This commit is contained in:
commit
9edd678164
|
@ -45,6 +45,12 @@ int BiasAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
if (ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
for (const auto &input : inputs) {
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
input_shape_ = Convert2SizeTClipNeg(inputs[kIndex0]->GetShapeVector());
|
||||
bias_shape_ = Convert2SizeTClipNeg(inputs[kIndex1]->GetShapeVector());
|
||||
data_shape_ = input_shape_.size();
|
||||
|
@ -143,6 +149,7 @@ const std::vector<std::pair<KernelAttr, BiasAddCpuKernelMod::KernelRunFunc>> &Bi
|
|||
MakeKernelFunc<int8_t>(kNumberTypeInt8),
|
||||
MakeKernelFunc<int16_t>(kNumberTypeInt16),
|
||||
MakeKernelFunc<int64_t>(kNumberTypeInt32),
|
||||
MakeKernelFunc<int64_t>(kNumberTypeInt64),
|
||||
MakeKernelFunc<uint8_t>(kNumberTypeUInt8),
|
||||
MakeKernelFunc<uint16_t>(kNumberTypeUInt16),
|
||||
MakeKernelFunc<uint32_t>(kNumberTypeUInt32),
|
||||
|
|
|
@ -45,6 +45,12 @@ int BiasAddGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
for (const auto &input : inputs) {
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
||||
auto num_dims = x_shape.size();
|
||||
is_null_input_ = CHECK_SHAPE_NULL(x_shape, kernel_name_, "input_x");
|
||||
|
|
Loading…
Reference in New Issue