forked from mindspore-Ecosystem/mindspore
addcmul support different type on CPU
This commit is contained in:
parent
8eb65baf62
commit
3ca067fbae
|
@ -59,7 +59,8 @@ int AddcmulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
dtype_ = inputs[kInputData]->GetDtype();
|
||||
dtype_value_ = inputs[kInputValue]->GetDtype();
|
||||
input_shape0_ = inputs[kInputData]->GetDeviceShapeAdaptively();
|
||||
input_shape1_ = inputs[kInputX1]->GetDeviceShapeAdaptively();
|
||||
input_shape2_ = inputs[kInputX2]->GetDeviceShapeAdaptively();
|
||||
|
@ -94,17 +95,17 @@ void AddcmulCpuKernelMod::AddcmulMul1(const T *input1, const T *input2, T *outpu
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AddcmulCpuKernelMod::AddcmulMul2(const T *input1, const T *input2, T *output) {
|
||||
template <typename T1, typename T2>
|
||||
void AddcmulCpuKernelMod::AddcmulMul2(const T2 *input1, const T1 *input2, T1 *output) {
|
||||
if ((inputx_shape_size_ + inputy_shape_size_ + value_shape_size_) == 0) {
|
||||
output[0] = static_cast<T>(input1[0] * input2[0]);
|
||||
output[0] = static_cast<T1>(input1[0]) * input2[0];
|
||||
} else {
|
||||
BroadcastIterator base_iter(input_shape3_, output_shape_, output_shape_);
|
||||
auto task = [&input1, &input2, &output, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; i++) {
|
||||
output[i] = static_cast<T>(input1[iter.GetInputPosA()] * input2[iter.GetInputPosB()]);
|
||||
output[i] = static_cast<T1>(input1[iter.GetInputPosA()]) * input2[iter.GetInputPosB()];
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
|
@ -139,13 +140,29 @@ void AddcmulCpuKernelMod::AddcmulAdd(const T *input1, const T *input2, T *output
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool AddcmulCpuKernelMod::AddcmulCheck(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
if (dtype_value_ == kNumberTypeFloat16) {
|
||||
return AddcmulCompute<T, float16>(inputs, outputs);
|
||||
} else if (dtype_value_ == kNumberTypeFloat32) {
|
||||
return AddcmulCompute<T, float>(inputs, outputs);
|
||||
} else if (dtype_value_ == kNumberTypeFloat64) {
|
||||
return AddcmulCompute<T, double>(inputs, outputs);
|
||||
} else if (dtype_value_ == kNumberTypeInt32) {
|
||||
return AddcmulCompute<T, int>(inputs, outputs);
|
||||
} else if (dtype_value_ == kNumberTypeInt64) {
|
||||
return AddcmulCompute<T, int64_t>(inputs, outputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
bool AddcmulCpuKernelMod::AddcmulCompute(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto *input0 = static_cast<T *>(inputs[kInputData]->addr);
|
||||
const auto *input1 = static_cast<T *>(inputs[kInputX1]->addr);
|
||||
const auto *input2 = static_cast<T *>(inputs[kInputX2]->addr);
|
||||
const auto *input3 = static_cast<T *>(inputs[kInputValue]->addr);
|
||||
auto *output = static_cast<T *>(outputs[kOutputData]->addr);
|
||||
auto *input0 = static_cast<T1 *>(inputs[kInputData]->addr);
|
||||
const auto *input1 = static_cast<T1 *>(inputs[kInputX1]->addr);
|
||||
const auto *input2 = static_cast<T1 *>(inputs[kInputX2]->addr);
|
||||
const auto *input3 = static_cast<T2 *>(inputs[kInputValue]->addr);
|
||||
auto *output = static_cast<T1 *>(outputs[kOutputData]->addr);
|
||||
|
||||
AddcmulMul1(input1, input2, output);
|
||||
AddcmulMul2(input3, output, output);
|
||||
|
@ -159,19 +176,19 @@ bool AddcmulCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
return AddcmulCompute<float>(inputs, outputs);
|
||||
return AddcmulCheck<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat16) {
|
||||
return AddcmulCompute<float16>(inputs, outputs);
|
||||
return AddcmulCheck<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
return AddcmulCompute<double>(inputs, outputs);
|
||||
return AddcmulCheck<double>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt32) {
|
||||
return AddcmulCompute<int>(inputs, outputs);
|
||||
return AddcmulCheck<int>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
return AddcmulCompute<int64_t>(inputs, outputs);
|
||||
return AddcmulCheck<int64_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeUInt8) {
|
||||
return AddcmulCompute<uint8_t>(inputs, outputs);
|
||||
return AddcmulCheck<uint8_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt8) {
|
||||
return AddcmulCompute<int8_t>(inputs, outputs);
|
||||
return AddcmulCheck<int8_t>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the type of 'x' should be float16, float32, float64, int8, uint8,int32, int64, "
|
||||
|
|
|
@ -46,6 +46,7 @@ class AddcmulCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
private:
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
TypeId dtype_value_{kTypeUnknown};
|
||||
std::vector<int64_t> input_shape0_;
|
||||
std::vector<int64_t> input_shape1_;
|
||||
std::vector<int64_t> input_shape2_;
|
||||
|
@ -59,13 +60,16 @@ class AddcmulCpuKernelMod : public NativeCpuKernelMod {
|
|||
ArithmeticParameter mul_para_{};
|
||||
|
||||
template <typename T>
|
||||
bool AddcmulCheck(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
template <typename T1, typename T2>
|
||||
bool AddcmulCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
void AddcmulAdd(const T *input1, const T *input2, T *out);
|
||||
template <typename T>
|
||||
void AddcmulMul1(const T *input1, const T *input2, T *out);
|
||||
template <typename T>
|
||||
void AddcmulMul2(const T *input1, const T *input2, T *out);
|
||||
template <typename T1, typename T2>
|
||||
void AddcmulMul2(const T2 *input1, const T1 *input2, T1 *out);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue