SparseSparseMaximum Check if A shape matchs with B shape
This commit is contained in:
parent
73a7bc9b28
commit
36aa35e15e
|
@ -201,6 +201,17 @@ void SparseSparseMaximumCpuKernelMod::CheckInputShape(const std::vector<KernelTe
|
|||
}
|
||||
}
|
||||
|
||||
void SparseSparseMaximumCpuKernelMod::CheckShapeMatch(const std::vector<AddressPtr> &inputs) {
|
||||
auto a_shape_ptr = reinterpret_cast<int64_t *>(inputs[kInputa_shapes]->addr);
|
||||
auto b_shape_ptr = reinterpret_cast<int64_t *>(inputs[kInputb_shapes]->addr);
|
||||
for (int64_t i = 0; i < num_dims_; ++i) {
|
||||
if (a_shape_ptr[i] != b_shape_ptr[i]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', operand's shapes do not match at index " << i
|
||||
<< ", got value: " << a_shape_ptr[i] << ", and " << b_shape_ptr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SparseSparseMaximumCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
|
@ -227,6 +238,7 @@ bool SparseSparseMaximumCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
|||
const int64_t a_nnz = a_nnz_;
|
||||
const int64_t num_dims = num_dims_;
|
||||
const int64_t b_nnz = b_nnz_;
|
||||
CheckShapeMatch(inputs);
|
||||
|
||||
auto a_values_ptr = reinterpret_cast<T *>(inputs[kInputa_values]->addr);
|
||||
Eigen::DSizes<Eigen::DenseIndex, 1> a_values_size(EIGEN_SHAPE_CAST(a_values_shape0_));
|
||||
|
|
|
@ -51,6 +51,7 @@ class SparseSparseMaximumCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
void CheckInputShape(const std::vector<KernelTensorPtr> &inputs, const int64_t a_nnz, const int64_t b_nnz,
|
||||
const int64_t num_dims);
|
||||
void CheckShapeMatch(const std::vector<AddressPtr> &inputs);
|
||||
|
||||
std::vector<KernelTensorPtr> outputs_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
|
|
Loading…
Reference in New Issue