forked from mindspore-Ecosystem/mindspore
!42840 Fix bias_add problem
Merge pull request !42840 from LiangZhibo/bias
This commit is contained in:
commit
9edd678164
|
@ -45,6 +45,12 @@ int BiasAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
||||||
if (ret != KRET_OK) {
|
if (ret != KRET_OK) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
auto input_shape = input->GetShapeVector();
|
||||||
|
if (!IsValidShape(input_shape)) {
|
||||||
|
return KRET_UNKNOWN_SHAPE;
|
||||||
|
}
|
||||||
|
}
|
||||||
input_shape_ = Convert2SizeTClipNeg(inputs[kIndex0]->GetShapeVector());
|
input_shape_ = Convert2SizeTClipNeg(inputs[kIndex0]->GetShapeVector());
|
||||||
bias_shape_ = Convert2SizeTClipNeg(inputs[kIndex1]->GetShapeVector());
|
bias_shape_ = Convert2SizeTClipNeg(inputs[kIndex1]->GetShapeVector());
|
||||||
data_shape_ = input_shape_.size();
|
data_shape_ = input_shape_.size();
|
||||||
|
@ -143,6 +149,7 @@ const std::vector<std::pair<KernelAttr, BiasAddCpuKernelMod::KernelRunFunc>> &Bi
|
||||||
MakeKernelFunc<int8_t>(kNumberTypeInt8),
|
MakeKernelFunc<int8_t>(kNumberTypeInt8),
|
||||||
MakeKernelFunc<int16_t>(kNumberTypeInt16),
|
MakeKernelFunc<int16_t>(kNumberTypeInt16),
|
||||||
MakeKernelFunc<int64_t>(kNumberTypeInt32),
|
MakeKernelFunc<int64_t>(kNumberTypeInt32),
|
||||||
|
MakeKernelFunc<int64_t>(kNumberTypeInt64),
|
||||||
MakeKernelFunc<uint8_t>(kNumberTypeUInt8),
|
MakeKernelFunc<uint8_t>(kNumberTypeUInt8),
|
||||||
MakeKernelFunc<uint16_t>(kNumberTypeUInt16),
|
MakeKernelFunc<uint16_t>(kNumberTypeUInt16),
|
||||||
MakeKernelFunc<uint32_t>(kNumberTypeUInt32),
|
MakeKernelFunc<uint32_t>(kNumberTypeUInt32),
|
||||||
|
|
|
@ -45,6 +45,12 @@ int BiasAddGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
||||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
auto input_shape = input->GetShapeVector();
|
||||||
|
if (!IsValidShape(input_shape)) {
|
||||||
|
return KRET_UNKNOWN_SHAPE;
|
||||||
|
}
|
||||||
|
}
|
||||||
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
||||||
auto num_dims = x_shape.size();
|
auto num_dims = x_shape.size();
|
||||||
is_null_input_ = CHECK_SHAPE_NULL(x_shape, kernel_name_, "input_x");
|
is_null_input_ = CHECK_SHAPE_NULL(x_shape, kernel_name_, "input_x");
|
||||||
|
|
Loading…
Reference in New Issue