Addcmul: support a 'value' input whose data type differs from the other inputs on CPU

This commit is contained in:
jianghui58 2023-02-18 16:22:23 +08:00
parent 8eb65baf62
commit 3ca067fbae
2 changed files with 40 additions and 19 deletions

View File

@ -59,7 +59,8 @@ int AddcmulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
dtype_ = inputs[kInputData]->GetDtype();
dtype_value_ = inputs[kInputValue]->GetDtype();
input_shape0_ = inputs[kInputData]->GetDeviceShapeAdaptively();
input_shape1_ = inputs[kInputX1]->GetDeviceShapeAdaptively();
input_shape2_ = inputs[kInputX2]->GetDeviceShapeAdaptively();
@ -94,17 +95,17 @@ void AddcmulCpuKernelMod::AddcmulMul1(const T *input1, const T *input2, T *outpu
}
}
template <typename T>
void AddcmulCpuKernelMod::AddcmulMul2(const T *input1, const T *input2, T *output) {
template <typename T1, typename T2>
void AddcmulCpuKernelMod::AddcmulMul2(const T2 *input1, const T1 *input2, T1 *output) {
if ((inputx_shape_size_ + inputy_shape_size_ + value_shape_size_) == 0) {
output[0] = static_cast<T>(input1[0] * input2[0]);
output[0] = static_cast<T1>(input1[0]) * input2[0];
} else {
BroadcastIterator base_iter(input_shape3_, output_shape_, output_shape_);
auto task = [&input1, &input2, &output, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
output[i] = static_cast<T>(input1[iter.GetInputPosA()] * input2[iter.GetInputPosB()]);
output[i] = static_cast<T1>(input1[iter.GetInputPosA()]) * input2[iter.GetInputPosB()];
iter.GenNextPos();
}
};
@ -139,13 +140,29 @@ void AddcmulCpuKernelMod::AddcmulAdd(const T *input1, const T *input2, T *output
}
template <typename T>
bool AddcmulCpuKernelMod::AddcmulCheck(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
if (dtype_value_ == kNumberTypeFloat16) {
return AddcmulCompute<T, float16>(inputs, outputs);
} else if (dtype_value_ == kNumberTypeFloat32) {
return AddcmulCompute<T, float>(inputs, outputs);
} else if (dtype_value_ == kNumberTypeFloat64) {
return AddcmulCompute<T, double>(inputs, outputs);
} else if (dtype_value_ == kNumberTypeInt32) {
return AddcmulCompute<T, int>(inputs, outputs);
} else if (dtype_value_ == kNumberTypeInt64) {
return AddcmulCompute<T, int64_t>(inputs, outputs);
}
return true;
}
template <typename T1, typename T2>
bool AddcmulCpuKernelMod::AddcmulCompute(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto *input0 = static_cast<T *>(inputs[kInputData]->addr);
const auto *input1 = static_cast<T *>(inputs[kInputX1]->addr);
const auto *input2 = static_cast<T *>(inputs[kInputX2]->addr);
const auto *input3 = static_cast<T *>(inputs[kInputValue]->addr);
auto *output = static_cast<T *>(outputs[kOutputData]->addr);
auto *input0 = static_cast<T1 *>(inputs[kInputData]->addr);
const auto *input1 = static_cast<T1 *>(inputs[kInputX1]->addr);
const auto *input2 = static_cast<T1 *>(inputs[kInputX2]->addr);
const auto *input3 = static_cast<T2 *>(inputs[kInputValue]->addr);
auto *output = static_cast<T1 *>(outputs[kOutputData]->addr);
AddcmulMul1(input1, input2, output);
AddcmulMul2(input3, output, output);
@ -159,19 +176,19 @@ bool AddcmulCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
if (dtype_ == kNumberTypeFloat32) {
return AddcmulCompute<float>(inputs, outputs);
return AddcmulCheck<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
return AddcmulCompute<float16>(inputs, outputs);
return AddcmulCheck<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
return AddcmulCompute<double>(inputs, outputs);
return AddcmulCheck<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32) {
return AddcmulCompute<int>(inputs, outputs);
return AddcmulCheck<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) {
return AddcmulCompute<int64_t>(inputs, outputs);
return AddcmulCheck<int64_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeUInt8) {
return AddcmulCompute<uint8_t>(inputs, outputs);
return AddcmulCheck<uint8_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt8) {
return AddcmulCompute<int8_t>(inputs, outputs);
return AddcmulCheck<int8_t>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'x' should be float16, float32, float64, int8, uint8,int32, int64, "

View File

@ -46,6 +46,7 @@ class AddcmulCpuKernelMod : public NativeCpuKernelMod {
private:
TypeId dtype_{kTypeUnknown};
TypeId dtype_value_{kTypeUnknown};
std::vector<int64_t> input_shape0_;
std::vector<int64_t> input_shape1_;
std::vector<int64_t> input_shape2_;
@ -59,13 +60,16 @@ class AddcmulCpuKernelMod : public NativeCpuKernelMod {
ArithmeticParameter mul_para_{};
template <typename T>
bool AddcmulCheck(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T1, typename T2>
bool AddcmulCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void AddcmulAdd(const T *input1, const T *input2, T *out);
template <typename T>
void AddcmulMul1(const T *input1, const T *input2, T *out);
template <typename T>
void AddcmulMul2(const T *input1, const T *input2, T *out);
template <typename T1, typename T2>
void AddcmulMul2(const T2 *input1, const T1 *input2, T1 *out);
};
} // namespace kernel
} // namespace mindspore