!34255 fix maximum_grad bug for gpu backend.

Merge pull request !34255 from zhuzhongrui/pub_master
i-robot 2022-05-12 01:17:41 +00:00 committed by Gitee
commit f67a6d2d31
2 changed files with 42 additions and 22 deletions
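
For context on what the fixed kernel computes: the gradient of an elementwise maximum routes each incoming gradient value to whichever operand won the comparison, which is why MaximumGrad (and MinimumGrad) takes three inputs (x, y, dout) and produces two outputs (dx, dy) of one common dtype, matching the corrected registrations in the diff below. A minimal host-side sketch of that semantics, for illustration only (the tie-breaking rule toward x is an assumption, and the real GPU kernel additionally reduces gradients over broadcast axes):

#include <cstddef>
#include <vector>

// Illustrative reference for MaximumGrad on equal-shaped inputs:
// dx[i] receives dout[i] where x[i] >= y[i], otherwise dy[i] receives it.
void MaximumGradReference(const std::vector<float> &x, const std::vector<float> &y,
                          const std::vector<float> &dout,
                          std::vector<float> *dx, std::vector<float> *dy) {
  dx->assign(x.size(), 0.0f);
  dy->assign(y.size(), 0.0f);
  for (std::size_t i = 0; i < dout.size(); ++i) {
    if (x[i] >= y[i]) {
      (*dx)[i] += dout[i];
    } else {
      (*dy)[i] += dout[i];
    }
  }
}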

@@ -43,7 +43,7 @@ bool BroadcastOpGradGpuKernelMod::GetOpType() {
  };
  auto iter = kBroadcastTypeMap.find(kernel_name_);
  if (iter == kBroadcastTypeMap.end()) {
- MS_LOG(ERROR) << "For 'MaximumGrad' or 'MinimumGrad' only support max and min grad, but got " << kernel_name_;
+ MS_LOG(ERROR) << "For 'MaximumGrad' or 'MinimumGrad', it only support max and min grad, but got " << kernel_name_;
  return false;
  } else {
  op_type_ = iter->second;
@@ -55,7 +55,7 @@ bool BroadcastOpGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, con
  const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  if (inputs.empty() || outputs.empty()) {
- MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
+ MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
  return false;
  }
  if (!GetOpType()) {
@@ -73,7 +73,7 @@ bool BroadcastOpGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, con
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
- MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
+ MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
  return false;
  }
  kernel_func_ = func_list_[index].second;
@@ -84,14 +84,11 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
  const std::vector<KernelTensorPtr> &inputs,
  const std::vector<KernelTensorPtr> &outputs,
  const std::map<uint32_t, tensor::TensorPtr> &) {
- ResetResource();
- for (const auto &input : inputs) {
- auto input_shape = input->GetShapeVector();
- if (!IsValidShape(input_shape)) {
- return KRET_UNKNOWN_SHAPE;
- }
+ if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+ return ret;
+ }
+ ResetResource();
+ unit_size_ = GetTypeByte(TypeIdToType(inputs.at(kIndex0)->GetDtype()));
  std::vector<size_t> shape0;
  auto origin_shape0 = inputs.at(kIndex0)->GetShapeVector();
  (void)std::transform(origin_shape0.begin(), origin_shape0.end(), std::back_inserter(shape0), LongToSize);
@@ -105,13 +102,12 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
  is_null_input_ = CHECK_SHAPE_NULL(shape0, kernel_name_, "input_0") ||
  CHECK_SHAPE_NULL(shape1, kernel_name_, "input_1") ||
  CHECK_SHAPE_NULL(shape2, kernel_name_, "input_2");
- unit_size_ = GetTypeByte(TypeIdToType(inputs.at(kIndex0)->GetDtype()));
  if (is_null_input_) {
  return KRET_OK;
  }
  need_broadcast_ = IsBroadcast(shape0, shape1);
  if (need_broadcast_ && shape0.size() > kMaxShapeSize) {
- MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxShapeSize
+ MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than " << kMaxShapeSize
  << ", but got " << shape0.size();
  return KRET_RESIZE_FAILED;
  }
@@ -130,8 +126,8 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
  x0_shape_[i + x0_offset] = shape0[i];
  } else {
  auto index = i + x0_offset;
- MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxShapeSize
- << ", but got " << (index + 1);
+ MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than "
+ << kMaxShapeSize << ", but got " << (index + 1);
  return KRET_RESIZE_FAILED;
  }
  }
@@ -144,8 +140,8 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
  x1_shape_[i + x1_offset] = shape1[i];
  } else {
  auto index = i + x1_offset;
- MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxShapeSize
- << ", but got " << (index + 1);
+ MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than "
+ << kMaxShapeSize << ", but got " << (index + 1);
  return KRET_RESIZE_FAILED;
  }
  }
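
The two loops above (and their error branches) right-align each input shape inside the fixed-length x0_shape_ / x1_shape_ arrays, leaving the leading positions at their padded value of 1 so that the two shapes can be compared axis-by-axis for broadcasting. A hedged sketch of that padding step, with illustrative names and an assumed rank limit:

#include <array>
#include <cstddef>
#include <vector>

constexpr std::size_t kMaxDims = 7;  // stand-in for kMaxShapeSize; the real limit may differ

// Right-align `shape` inside a fixed-size array, padding leading axes with 1,
// so two shapes can be compared position-by-position for broadcasting.
// Returns false when the shape has more axes than the array can hold.
bool AlignShape(const std::vector<std::size_t> &shape, std::array<std::size_t, kMaxDims> *aligned) {
  if (shape.size() > kMaxDims) {
    return false;
  }
  aligned->fill(1);
  const std::size_t offset = kMaxDims - shape.size();
  for (std::size_t i = 0; i < shape.size(); ++i) {
    (*aligned)[i + offset] = shape[i];
  }
  return true;
}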
@@ -204,15 +200,40 @@ bool BroadcastOpGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
  }
  std::vector<std::pair<KernelAttr, BroadcastOpGradGpuKernelMod::BroadCastFunc>> BroadcastOpGradGpuKernelMod::func_list_ =
- {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
+ {{KernelAttr()
+ .AddInputAttr(kNumberTypeInt32)
+ .AddInputAttr(kNumberTypeInt32)
+ .AddInputAttr(kNumberTypeInt32)
+ .AddOutputAttr(kNumberTypeInt32)
+ .AddOutputAttr(kNumberTypeInt32),
  &BroadcastOpGradGpuKernelMod::LaunchKernel<int>},
- {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+ {KernelAttr()
+ .AddInputAttr(kNumberTypeInt64)
+ .AddInputAttr(kNumberTypeInt64)
+ .AddInputAttr(kNumberTypeInt64)
+ .AddOutputAttr(kNumberTypeInt64)
+ .AddOutputAttr(kNumberTypeInt64),
  &BroadcastOpGradGpuKernelMod::LaunchKernel<int64_t>},
- {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
+ {KernelAttr()
+ .AddInputAttr(kNumberTypeFloat16)
+ .AddInputAttr(kNumberTypeFloat16)
+ .AddInputAttr(kNumberTypeFloat16)
+ .AddOutputAttr(kNumberTypeFloat16)
+ .AddOutputAttr(kNumberTypeFloat16),
  &BroadcastOpGradGpuKernelMod::LaunchKernel<half>},
- {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
+ {KernelAttr()
+ .AddInputAttr(kNumberTypeFloat32)
+ .AddInputAttr(kNumberTypeFloat32)
+ .AddInputAttr(kNumberTypeFloat32)
+ .AddOutputAttr(kNumberTypeFloat32)
+ .AddOutputAttr(kNumberTypeFloat32),
  &BroadcastOpGradGpuKernelMod::LaunchKernel<float>},
- {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
+ {KernelAttr()
+ .AddInputAttr(kNumberTypeFloat64)
+ .AddInputAttr(kNumberTypeFloat64)
+ .AddInputAttr(kNumberTypeFloat64)
+ .AddOutputAttr(kNumberTypeFloat64)
+ .AddOutputAttr(kNumberTypeFloat64),
  &BroadcastOpGradGpuKernelMod::LaunchKernel<double>}};
  std::vector<KernelAttr> BroadcastOpGradGpuKernelMod::GetOpSupport() {

@@ -64,7 +64,6 @@ class BroadcastOpGradGpuKernelMod : public NativeGpuKernelMod {
  using BroadCastFunc = std::function<bool(BroadcastOpGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
  const std::vector<kernel::AddressPtr> &)>;
  private:
  BroadcastGradOpType op_type_{BROADCAST_GRAD_TYPE_INVALID};
  bool need_broadcast_{false};
  bool is_null_input_{false};