SparseSparseMaximum Check if A shape matchs with B shape

This commit is contained in:
hw_hz 2022-11-07 19:34:44 +08:00
parent 73a7bc9b28
commit 36aa35e15e
2 changed files with 13 additions and 0 deletions

View File

@ -201,6 +201,17 @@ void SparseSparseMaximumCpuKernelMod::CheckInputShape(const std::vector<KernelTe
}
}
void SparseSparseMaximumCpuKernelMod::CheckShapeMatch(const std::vector<AddressPtr> &inputs) {
auto a_shape_ptr = reinterpret_cast<int64_t *>(inputs[kInputa_shapes]->addr);
auto b_shape_ptr = reinterpret_cast<int64_t *>(inputs[kInputb_shapes]->addr);
for (int64_t i = 0; i < num_dims_; ++i) {
if (a_shape_ptr[i] != b_shape_ptr[i]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', operand's shapes do not match at index " << i
<< ", got value: " << a_shape_ptr[i] << ", and " << b_shape_ptr[i];
}
}
}
bool SparseSparseMaximumCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
@ -227,6 +238,7 @@ bool SparseSparseMaximumCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
const int64_t a_nnz = a_nnz_;
const int64_t num_dims = num_dims_;
const int64_t b_nnz = b_nnz_;
CheckShapeMatch(inputs);
auto a_values_ptr = reinterpret_cast<T *>(inputs[kInputa_values]->addr);
Eigen::DSizes<Eigen::DenseIndex, 1> a_values_size(EIGEN_SHAPE_CAST(a_values_shape0_));

View File

@ -51,6 +51,7 @@ class SparseSparseMaximumCpuKernelMod : public NativeCpuKernelMod {
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void CheckInputShape(const std::vector<KernelTensorPtr> &inputs, const int64_t a_nnz, const int64_t b_nnz,
const int64_t num_dims);
void CheckShapeMatch(const std::vector<AddressPtr> &inputs);
std::vector<KernelTensorPtr> outputs_;
TypeId dtype_{kTypeUnknown};