!48823 throw when inputs of batchmatmul need to broadcast on cpu/gpu

Merge pull request !48823 from zhoufeng/xiu-ba-ge
This commit is contained in:
i-robot 2023-02-16 06:26:10 +00:00 committed by Gitee
commit 225751b4e4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 19 additions and 0 deletions

View File

@ -92,6 +92,25 @@ void CheckBatchMatmulInputWhetherCanBeBroadcast(const std::string &name, const S
return;
}
// todo: delete after broadcast shape is supported on cpu/gpu
#if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES))
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == kGPUDevice || device_target == kCPUDevice) {
if (x_batch.size() != y_batch.size()) {
MS_EXCEPTION(ValueError) << "For " << name << ", inputs shape cannot be broadcast on CPU/GPU, with x shape "
<< x_shape << ", y shape " << y_shape;
}
for (size_t i = 0; i < x_batch.size(); ++i) {
if (x_batch[i] != y_batch[i]) {
MS_EXCEPTION(ValueError) << "For " << name << ", inputs shape cannot be broadcast on CPU/GPU, with x shape "
<< x_shape << ", y shape " << y_shape;
}
}
}
#endif
size_t min_size = std::min(x_batch.size(), y_batch.size());
for (size_t i = 0; i < min_size; ++i) {
auto x = *(x_batch.rbegin() + i);