!45125 [assistant][ops] Add SparseSlice & SparseSliceGrad code fix
Merge pull request !45125 from YR0717/dev-fix
This commit is contained in:
commit
4372fff36a
|
@ -13,10 +13,10 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_slice_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
|
||||
|
@ -27,8 +27,8 @@ using complex64 = std::complex<float>;
|
|||
using complex128 = std::complex<double>;
|
||||
constexpr int64_t kSparseSliceInputsNum = 5;
|
||||
constexpr int64_t kSparseSliceOutputsNum = 3;
|
||||
constexpr int64_t dim0num = 1;
|
||||
constexpr int64_t dim1num = 2;
|
||||
constexpr int64_t kDim0Num = 1;
|
||||
constexpr int64_t kDim1Num = 2;
|
||||
|
||||
#define ADD_KERNEL(dtype, type) \
|
||||
{ \
|
||||
|
@ -68,23 +68,23 @@ int SparseSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
const auto input_start_shape = inputs[kIndex3]->GetShapeVector();
|
||||
const auto input_size_shape = inputs[kIndex4]->GetShapeVector();
|
||||
|
||||
if (input_indices_shape.size() != dim1num) {
|
||||
if (input_indices_shape.size() != kDim1Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_indices_shape' must be 2D Tensor "
|
||||
<< ", but got " << input_indices_shape.size() << "-D";
|
||||
}
|
||||
if (input_values_shape.size() != dim0num) {
|
||||
if (input_values_shape.size() != kDim0Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_values_shape' must be 1D Tensor "
|
||||
<< ", but got " << input_values_shape.size() << "-D";
|
||||
}
|
||||
if (input_shape_shape.size() != dim0num) {
|
||||
if (input_shape_shape.size() != kDim0Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_shape_shape' must be 1D Tensor "
|
||||
<< ", but got " << input_shape_shape.size() << "-D";
|
||||
}
|
||||
if (input_start_shape.size() != dim0num) {
|
||||
if (input_start_shape.size() != kDim0Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_start_shape' must be 1D Tensor "
|
||||
<< ", but got " << input_start_shape.size() << "-D";
|
||||
}
|
||||
if (input_size_shape.size() != dim0num) {
|
||||
if (input_size_shape.size() != kDim0Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'input_size_shape' must be 1D Tensor "
|
||||
<< ", but got " << input_size_shape.size() << "-D";
|
||||
}
|
||||
|
|
|
@ -13,10 +13,10 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_slice_grad_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
|
||||
|
@ -27,8 +27,8 @@ using complex64 = std::complex<float>;
|
|||
using complex128 = std::complex<double>;
|
||||
constexpr int64_t kSparseSliceGradInputsNum = 4;
|
||||
constexpr int64_t kSparseSliceGradOutputsNum = 1;
|
||||
constexpr int64_t dim0num = 1;
|
||||
constexpr int64_t dim1num = 2;
|
||||
constexpr int64_t kDim0Num = 1;
|
||||
constexpr int64_t kDim1Num = 2;
|
||||
|
||||
#define ADD_KERNEL(dtype, type) \
|
||||
{ \
|
||||
|
@ -65,19 +65,19 @@ int SparseSliceGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
const auto new_indices_shape = inputs[kIndex3]->GetShapeVector();
|
||||
|
||||
// Check shape
|
||||
if (backprop_val_grad_shape.size() != dim0num) {
|
||||
if (backprop_val_grad_shape.size() != kDim0Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'brackprop_val_gard' must be 1D Tensor "
|
||||
<< ", but got " << backprop_val_grad_shape.size() << "-D";
|
||||
}
|
||||
if (indices_shape.size() != dim1num) {
|
||||
if (indices_shape.size() != kDim1Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'indices_shape' must be 2D Tensor "
|
||||
<< ", but got " << indices_shape.size() << "-D";
|
||||
}
|
||||
if (start_shape.size() != dim0num) {
|
||||
if (start_shape.size() != kDim0Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'start_shape' must be 1D Tensor "
|
||||
<< ", but got " << start_shape.size() << "-D";
|
||||
}
|
||||
if (new_indices_shape.size() != dim1num) {
|
||||
if (new_indices_shape.size() != kDim1Num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'new_indices_shape' must be 2D Tensor "
|
||||
<< ", but got " << new_indices_shape.size() << "-D";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue