!45125 [assistant][ops] Add SparseSlice & SparseSliceGrad code fix

Merge pull request !45125 from YR0717/dev-fix
This commit is contained in:
i-robot 2022-11-04 06:47:07 +00:00 committed by Gitee
commit 4372fff36a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 17 additions and 17 deletions

View File

@ -13,10 +13,10 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <algorithm>
#include <complex>
#include "plugin/device/cpu/kernel/sparse_slice_cpu_kernel.h" #include "plugin/device/cpu/kernel/sparse_slice_cpu_kernel.h"
#include <algorithm>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h" #include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h" #include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
@ -27,8 +27,8 @@ using complex64 = std::complex<float>;
using complex128 = std::complex<double>; using complex128 = std::complex<double>;
constexpr int64_t kSparseSliceInputsNum = 5; constexpr int64_t kSparseSliceInputsNum = 5;
constexpr int64_t kSparseSliceOutputsNum = 3; constexpr int64_t kSparseSliceOutputsNum = 3;
constexpr int64_t dim0num = 1; constexpr int64_t kDim0Num = 1;
constexpr int64_t dim1num = 2; constexpr int64_t kDim1Num = 2;
#define ADD_KERNEL(dtype, type) \ #define ADD_KERNEL(dtype, type) \
{ \ { \
@ -68,23 +68,23 @@ int SparseSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
const auto input_start_shape = inputs[kIndex3]->GetShapeVector(); const auto input_start_shape = inputs[kIndex3]->GetShapeVector();
const auto input_size_shape = inputs[kIndex4]->GetShapeVector(); const auto input_size_shape = inputs[kIndex4]->GetShapeVector();
if (input_indices_shape.size() != dim1num) { if (input_indices_shape.size() != kDim1Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_indices_shape' must be 2D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_indices_shape' must be 2D Tensor "
<< ", but got " << input_indices_shape.size() << "-D"; << ", but got " << input_indices_shape.size() << "-D";
} }
if (input_values_shape.size() != dim0num) { if (input_values_shape.size() != kDim0Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_values_shape' must be 1D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_values_shape' must be 1D Tensor "
<< ", but got " << input_values_shape.size() << "-D"; << ", but got " << input_values_shape.size() << "-D";
} }
if (input_shape_shape.size() != dim0num) { if (input_shape_shape.size() != kDim0Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_shape_shape' must be 1D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_shape_shape' must be 1D Tensor "
<< ", but got " << input_shape_shape.size() << "-D"; << ", but got " << input_shape_shape.size() << "-D";
} }
if (input_start_shape.size() != dim0num) { if (input_start_shape.size() != kDim0Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_start_shape' must be 1D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_start_shape' must be 1D Tensor "
<< ", but got " << input_start_shape.size() << "-D"; << ", but got " << input_start_shape.size() << "-D";
} }
if (input_size_shape.size() != dim0num) { if (input_size_shape.size() != kDim0Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_size_shape' must be 1D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_size_shape' must be 1D Tensor "
<< ", but got " << input_size_shape.size() << "-D"; << ", but got " << input_size_shape.size() << "-D";
} }

View File

@ -13,10 +13,10 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <algorithm>
#include <complex>
#include "plugin/device/cpu/kernel/sparse_slice_grad_cpu_kernel.h" #include "plugin/device/cpu/kernel/sparse_slice_grad_cpu_kernel.h"
#include <algorithm>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h" #include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h" #include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
@ -27,8 +27,8 @@ using complex64 = std::complex<float>;
using complex128 = std::complex<double>; using complex128 = std::complex<double>;
constexpr int64_t kSparseSliceGradInputsNum = 4; constexpr int64_t kSparseSliceGradInputsNum = 4;
constexpr int64_t kSparseSliceGradOutputsNum = 1; constexpr int64_t kSparseSliceGradOutputsNum = 1;
constexpr int64_t dim0num = 1; constexpr int64_t kDim0Num = 1;
constexpr int64_t dim1num = 2; constexpr int64_t kDim1Num = 2;
#define ADD_KERNEL(dtype, type) \ #define ADD_KERNEL(dtype, type) \
{ \ { \
@ -65,19 +65,19 @@ int SparseSliceGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const auto new_indices_shape = inputs[kIndex3]->GetShapeVector(); const auto new_indices_shape = inputs[kIndex3]->GetShapeVector();
// Check shape // Check shape
if (backprop_val_grad_shape.size() != dim0num) { if (backprop_val_grad_shape.size() != kDim0Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'brackprop_val_gard' must be 1D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'brackprop_val_gard' must be 1D Tensor "
<< ", but got " << backprop_val_grad_shape.size() << "-D"; << ", but got " << backprop_val_grad_shape.size() << "-D";
} }
if (indices_shape.size() != dim1num) { if (indices_shape.size() != kDim1Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'indices_shape' must be 2D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'indices_shape' must be 2D Tensor "
<< ", but got " << indices_shape.size() << "-D"; << ", but got " << indices_shape.size() << "-D";
} }
if (start_shape.size() != dim0num) { if (start_shape.size() != kDim0Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'start_shape' must be 1D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'start_shape' must be 1D Tensor "
<< ", but got " << start_shape.size() << "-D"; << ", but got " << start_shape.size() << "-D";
} }
if (new_indices_shape.size() != dim1num) { if (new_indices_shape.size() != kDim1Num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'new_indices_shape' must be 2D Tensor " MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'new_indices_shape' must be 2D Tensor "
<< ", but got " << new_indices_shape.size() << "-D"; << ", but got " << new_indices_shape.size() << "-D";
} }