forked from mindspore-Ecosystem/mindspore
commit
d475c95bdf
|
@ -55,12 +55,22 @@ int SparseSegmentSqrtNCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
indices_shape_ = inputs.at(kIndex1)->GetDeviceShapeAdaptively();
|
||||
segment_ids_shape_ = inputs.at(kIndex2)->GetDeviceShapeAdaptively();
|
||||
y_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
|
||||
is_null_input_ = CHECK_SHAPE_NULL(x_shape_, kernel_name_, "x_shape_") ||
|
||||
CHECK_SHAPE_NULL(indices_shape_, kernel_name_, "indices_shape_") ||
|
||||
CHECK_SHAPE_NULL(segment_ids_shape_, kernel_name_, "segment_ids_shape_");
|
||||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool SparseSegmentSqrtNCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
if (dtype1_ == kNumberTypeInt32) {
|
||||
if (dtype2_ == kNumberTypeInt32) {
|
||||
|
|
|
@ -54,6 +54,7 @@ class SparseSegmentSqrtNCpuKernelMod : public NativeCpuKernelMod {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
TypeId dtype1_{kTypeUnknown};
|
||||
TypeId dtype2_{kTypeUnknown};
|
||||
bool is_null_input_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue