forked from mindspore-Ecosystem/mindspore
!34255 fix maximum_grad bug for gpu backend.
Merge pull request !34255 from zhuzhongrui/pub_master
This commit is contained in:
commit
f67a6d2d31
|
@ -43,7 +43,7 @@ bool BroadcastOpGradGpuKernelMod::GetOpType() {
|
|||
};
|
||||
auto iter = kBroadcastTypeMap.find(kernel_name_);
|
||||
if (iter == kBroadcastTypeMap.end()) {
|
||||
MS_LOG(ERROR) << "For 'MaximumGrad' or 'MinimumGrad' only support max and min grad, but got " << kernel_name_;
|
||||
MS_LOG(ERROR) << "For 'MaximumGrad' or 'MinimumGrad', it only support max and min grad, but got " << kernel_name_;
|
||||
return false;
|
||||
} else {
|
||||
op_type_ = iter->second;
|
||||
|
@ -55,7 +55,7 @@ bool BroadcastOpGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, con
|
|||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
if (!GetOpType()) {
|
||||
|
@ -73,7 +73,7 @@ bool BroadcastOpGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, con
|
|||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
|
@ -84,14 +84,11 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
for (const auto &input : inputs) {
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ResetResource();
|
||||
unit_size_ = GetTypeByte(TypeIdToType(inputs.at(kIndex0)->GetDtype()));
|
||||
std::vector<size_t> shape0;
|
||||
auto origin_shape0 = inputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(origin_shape0.begin(), origin_shape0.end(), std::back_inserter(shape0), LongToSize);
|
||||
|
@ -105,13 +102,12 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
is_null_input_ = CHECK_SHAPE_NULL(shape0, kernel_name_, "input_0") ||
|
||||
CHECK_SHAPE_NULL(shape1, kernel_name_, "input_1") ||
|
||||
CHECK_SHAPE_NULL(shape2, kernel_name_, "input_2");
|
||||
unit_size_ = GetTypeByte(TypeIdToType(inputs.at(kIndex0)->GetDtype()));
|
||||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
need_broadcast_ = IsBroadcast(shape0, shape1);
|
||||
if (need_broadcast_ && shape0.size() > kMaxShapeSize) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxShapeSize
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than " << kMaxShapeSize
|
||||
<< ", but got " << shape0.size();
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
|
@ -130,8 +126,8 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
x0_shape_[i + x0_offset] = shape0[i];
|
||||
} else {
|
||||
auto index = i + x0_offset;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxShapeSize
|
||||
<< ", but got " << (index + 1);
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than "
|
||||
<< kMaxShapeSize << ", but got " << (index + 1);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -144,8 +140,8 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
x1_shape_[i + x1_offset] = shape1[i];
|
||||
} else {
|
||||
auto index = i + x1_offset;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxShapeSize
|
||||
<< ", but got " << (index + 1);
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than "
|
||||
<< kMaxShapeSize << ", but got " << (index + 1);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -204,15 +200,40 @@ bool BroadcastOpGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
|
|||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, BroadcastOpGradGpuKernelMod::BroadCastFunc>> BroadcastOpGradGpuKernelMod::func_list_ =
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&BroadcastOpGradGpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&BroadcastOpGradGpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&BroadcastOpGradGpuKernelMod::LaunchKernel<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&BroadcastOpGradGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&BroadcastOpGradGpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> BroadcastOpGradGpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -64,7 +64,6 @@ class BroadcastOpGradGpuKernelMod : public NativeGpuKernelMod {
|
|||
using BroadCastFunc = std::function<bool(BroadcastOpGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
BroadcastGradOpType op_type_{BROADCAST_GRAD_TYPE_INVALID};
|
||||
bool need_broadcast_{false};
|
||||
bool is_null_input_{false};
|
||||
|
|
Loading…
Reference in New Issue