!9956 support broadcast of Mul cpu op

From: @wuxuejian
Reviewed-by: @kisnwang,@oacjiewen
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-15 19:30:43 +08:00 committed by Gitee
commit 28696b9d4d
1 changed files with 6 additions and 2 deletions

View File

@ -32,7 +32,9 @@ void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t s
template <typename T>
void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] + input2[i];
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] + input2[idx[1]];
}
}
@ -48,7 +50,9 @@ void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t s
template <typename T>
void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] * input2[i];
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] * input2[idx[1]];
}
}