forked from mindspore-Ecosystem/mindspore
support cpu dynamic_broadcastto
This commit is contained in:
parent
ac39b6ebca
commit
aeb74b0d64
|
@ -20,7 +20,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kBroadcastToInputsNum = 1;
|
|
||||||
constexpr size_t kBroadcastToOutputsNum = 1;
|
constexpr size_t kBroadcastToOutputsNum = 1;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -32,6 +31,21 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||||
size_t input_shape_size = input_shape_.size();
|
size_t input_shape_size = input_shape_.size();
|
||||||
size_t output_shape_size = output_shape_.size();
|
size_t output_shape_size = output_shape_.size();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < input_shape_size; ++i) {
|
||||||
|
shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < output_shape_size; ++i) {
|
||||||
|
shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
|
||||||
|
}
|
||||||
|
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
|
||||||
|
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void BroadcastToCPUKernel<T>::CheckArgs() {
|
||||||
|
size_t input_shape_size = input_shape_.size();
|
||||||
|
size_t output_shape_size = output_shape_.size();
|
||||||
if (output_shape_size < input_shape_size) {
|
if (output_shape_size < input_shape_size) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||||
<< "', input tensor 'input_x' and target shape 'shape' can't "
|
<< "', input tensor 'input_x' and target shape 'shape' can't "
|
||||||
|
@ -56,22 +70,13 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||||
<< Vector2Str(input_shape_) << ", and the dimension of target shape 'shape': " << Vector2Str(output_shape_);
|
<< Vector2Str(input_shape_) << ", and the dimension of target shape 'shape': " << Vector2Str(output_shape_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < input_shape_size; ++i) {
|
|
||||||
shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < output_shape_size; ++i) {
|
|
||||||
shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
|
|
||||||
}
|
|
||||||
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
|
|
||||||
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool BroadcastToCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool BroadcastToCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBroadcastToInputsNum, kernel_name_);
|
|
||||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
|
||||||
|
CheckArgs();
|
||||||
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||||
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||||
int status = static_cast<int>(NNACL_OK);
|
int status = static_cast<int>(NNACL_OK);
|
||||||
|
|
|
@ -36,6 +36,8 @@ class BroadcastToCPUKernel : public CPUKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
void InitKernel(const CNodePtr &kernel_node) override;
|
void InitKernel(const CNodePtr &kernel_node) override;
|
||||||
|
|
||||||
|
void CheckArgs();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<size_t> input_shape_;
|
std::vector<size_t> input_shape_;
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
|
@ -48,6 +50,18 @@ MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).Add
|
||||||
BroadcastToCPUKernel, int);
|
BroadcastToCPUKernel, int);
|
||||||
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||||
BroadcastToCPUKernel, bool);
|
BroadcastToCPUKernel, bool);
|
||||||
|
MS_REG_CPU_KERNEL_T(
|
||||||
|
DynamicBroadcastTo,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
BroadcastToCPUKernel, float);
|
||||||
|
MS_REG_CPU_KERNEL_T(
|
||||||
|
DynamicBroadcastTo,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
BroadcastToCPUKernel, int);
|
||||||
|
MS_REG_CPU_KERNEL_T(
|
||||||
|
DynamicBroadcastTo,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||||
|
BroadcastToCPUKernel, bool);
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue