support cpu dynamic_broadcastto

This commit is contained in:
fangzehua 2021-12-14 15:28:03 +08:00
parent ac39b6ebca
commit aeb74b0d64
2 changed files with 30 additions and 11 deletions

View File

@ -20,7 +20,6 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace { namespace {
constexpr size_t kBroadcastToInputsNum = 1;
constexpr size_t kBroadcastToOutputsNum = 1; constexpr size_t kBroadcastToOutputsNum = 1;
} // namespace } // namespace
@ -32,6 +31,21 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
size_t input_shape_size = input_shape_.size(); size_t input_shape_size = input_shape_.size();
size_t output_shape_size = output_shape_.size(); size_t output_shape_size = output_shape_.size();
for (size_t i = 0; i < input_shape_size; ++i) {
shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
}
for (size_t i = 0; i < output_shape_size; ++i) {
shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
}
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
}
template <typename T>
void BroadcastToCPUKernel<T>::CheckArgs() {
size_t input_shape_size = input_shape_.size();
size_t output_shape_size = output_shape_.size();
if (output_shape_size < input_shape_size) { if (output_shape_size < input_shape_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', input tensor 'input_x' and target shape 'shape' can't " << "', input tensor 'input_x' and target shape 'shape' can't "
@ -56,22 +70,13 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
<< Vector2Str(input_shape_) << ", and the dimension of target shape 'shape': " << Vector2Str(output_shape_); << Vector2Str(input_shape_) << ", and the dimension of target shape 'shape': " << Vector2Str(output_shape_);
} }
} }
for (size_t i = 0; i < input_shape_size; ++i) {
shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
}
for (size_t i = 0; i < output_shape_size; ++i) {
shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
}
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
} }
template <typename T> template <typename T>
bool BroadcastToCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool BroadcastToCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBroadcastToInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
CheckArgs();
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr); const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr); auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
int status = static_cast<int>(NNACL_OK); int status = static_cast<int>(NNACL_OK);

View File

@ -36,6 +36,8 @@ class BroadcastToCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
void CheckArgs();
private: private:
std::vector<size_t> input_shape_; std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_; std::vector<size_t> output_shape_;
@ -48,6 +50,18 @@ MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).Add
BroadcastToCPUKernel, int); BroadcastToCPUKernel, int);
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastToCPUKernel, bool); BroadcastToCPUKernel, bool);
MS_REG_CPU_KERNEL_T(
DynamicBroadcastTo,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
BroadcastToCPUKernel, float);
MS_REG_CPU_KERNEL_T(
DynamicBroadcastTo,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastToCPUKernel, int);
MS_REG_CPU_KERNEL_T(
DynamicBroadcastTo,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastToCPUKernel, bool);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore