forked from mindspore-Ecosystem/mindspore
!48823 throw when inputs of batchmatmul need to broadcast on cpu/gpu
Merge pull request !48823 from zhoufeng/xiu-ba-ge
This commit is contained in:
commit
225751b4e4
|
@ -92,6 +92,25 @@ void CheckBatchMatmulInputWhetherCanBeBroadcast(const std::string &name, const S
|
|||
return;
|
||||
}
|
||||
|
||||
// todo: delete after broadcast shape is supported on cpu/gpu
|
||||
#if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES))
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target == kGPUDevice || device_target == kCPUDevice) {
|
||||
if (x_batch.size() != y_batch.size()) {
|
||||
MS_EXCEPTION(ValueError) << "For " << name << ", inputs shape cannot be broadcast on CPU/GPU, with x shape "
|
||||
<< x_shape << ", y shape " << y_shape;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < x_batch.size(); ++i) {
|
||||
if (x_batch[i] != y_batch[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << name << ", inputs shape cannot be broadcast on CPU/GPU, with x shape "
|
||||
<< x_shape << ", y shape " << y_shape;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
size_t min_size = std::min(x_batch.size(), y_batch.size());
|
||||
for (size_t i = 0; i < min_size; ++i) {
|
||||
auto x = *(x_batch.rbegin() + i);
|
||||
|
|
Loading…
Reference in New Issue