adapt some operators dynamic shape and implementations.

This commit is contained in:
y00451588 2022-10-20 13:02:41 +08:00
parent c5bafc2bce
commit 98fcdd7968
12 changed files with 327 additions and 118 deletions

View File

@ -18,6 +18,7 @@
#include <algorithm>
#include <utility>
#include <vector>
#include "mindspore/core/ops/concat_offset.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
@ -26,46 +27,70 @@ namespace {
constexpr size_t kConcatOffsetOutputNum = 1;
constexpr size_t kConcatOffsetOutputShapeSize = 2;
} // namespace
void ConcatOffsetCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
cnode_ptr_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
bool ConcatOffsetCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_ERROR_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
if (inputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << ", input tensors can not be empty";
return false;
}
auto op_prim = std::dynamic_pointer_cast<ops::ConcatOffset>(base_operator);
MS_ERROR_IF_NULL(op_prim);
if (op_prim->HasAttr(kAttrAxis)) {
axis_ = op_prim->get_axis();
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Concat offset does not support this kernel data type: " << kernel_attr;
MS_LOG(ERROR) << "Concat offset does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
template <typename T>
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
auto node_ = cnode_ptr_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cnode_ptr_(kernel_node) is expired. Error no: " << node_;
int ConcatOffsetCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto output_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
size_t input_num = common::AnfAlgo::GetInputTensorNum(node_);
if (input_num == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", input tensors can not be empty";
output_shape_ = outputs[kIndex0]->GetShapeVector();
if (output_shape_.size() != kConcatOffsetOutputShapeSize) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of output must be " << kConcatOffsetOutputShapeSize
<< ", but got:" << output_shape_.size();
return KRET_RESIZE_FAILED;
}
// check input shapes
std::vector<ShapeVector> input_shapes;
for (size_t i = 0; i < input_num; i++) {
ShapeVector input_shape_i = common::AnfAlgo::GetPrevNodeOutputInferShape(node_, i);
input_shapes.push_back(input_shape_i);
if (input_shape_i.size() != input_shapes[0].size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', input tensors shape's rank must be equal, but got input[0] shape's rank = "
<< input_shapes[0].size() << ", input[" << i << "] shape's rank = " << input_shape_i.size();
if (LongToSize(output_shape_[kIndex0]) != inputs.size()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the first dimension value of output must be equal to "
"the number of input, but got the first dimension value of output: "
<< output_shape_[kIndex0] << ", and the number of input: " << inputs.size();
return KRET_RESIZE_FAILED;
}
input_shapes_.clear();
for (size_t i = 0; i < inputs.size(); i++) {
ShapeVector shape_i = inputs[i]->GetShapeVector();
input_shapes_.push_back(shape_i);
if (shape_i.size() != input_shapes_[kIndex0].size()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', input tensors shape's rank must be equal, but got input[0] shape's rank = "
<< input_shapes_[kIndex0].size() << ", input[" << i << "] shape's rank = " << shape_i.size();
return KRET_RESIZE_FAILED;
}
}
// check axis
auto x_rank = SizeToLong(input_shapes[0].size());
return KRET_OK;
}
template <typename T>
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<int64_t *>(outputs[kIndex0]->addr);
auto x_rank = SizeToLong(input_shapes_[kIndex0].size());
if (axis_ < -x_rank || axis_ >= x_rank) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", 'axis' must be in range [-" << x_rank << ", " << x_rank
<< "), but got " << axis_;
@ -76,27 +101,16 @@ bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
auto axis = LongToSize(axis_);
ShapeVector offset{0};
auto all_shape = input_shapes[0][axis];
auto all_shape = input_shapes_[0][axis];
// cal offset
for (size_t i = 1; i < input_num; i++) {
for (size_t i = 1; i < inputs.size(); i++) {
offset.emplace_back(all_shape);
all_shape += input_shapes[i][axis];
all_shape += input_shapes_[i][axis];
}
auto output_shape = common::AnfAlgo::GetOutputInferShape(node_, 0);
if (output_shape.size() != kConcatOffsetOutputShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output must be "
<< kConcatOffsetOutputShapeSize << ", but got:" << output_shape.size();
}
if (LongToSize(output_shape[0]) != input_num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the first dimension value of output must be equal to "
"the number of input, but got the first dimension value of output: "
<< output_shape[0] << ", and the number of input: " << input_num;
}
size_t rank = LongToSize(output_shape[1]);
size_t rank = LongToSize(output_shape_[kIndex1]);
size_t idx = 0;
for (size_t i = 0; i < input_num; ++i) {
for (size_t i = 0; i < inputs.size(); ++i) {
for (size_t j = 0; j < rank; ++j) {
if (j == axis) {
output_addr[idx] = offset[i];

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_OFFSET_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_OFFSET_CPU_KERNEL_H_
#include <map>
#include <vector>
#include <memory>
#include <utility>
@ -25,12 +26,18 @@
namespace mindspore {
namespace kernel {
class ConcatOffsetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class ConcatOffsetCpuKernelMod : public NativeCpuKernelMod {
public:
ConcatOffsetCpuKernelMod() = default;
~ConcatOffsetCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
@ -47,6 +54,8 @@ class ConcatOffsetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
static std::vector<std::pair<KernelAttr, ConcatOffsetFunc>> func_list_;
ConcatOffsetFunc kernel_func_;
int64_t axis_{0};
std::vector<ShapeVector> input_shapes_;
ShapeVector output_shape_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -34,6 +34,7 @@ bool MatrixPowerCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
kernel_name_ = base_operator->name();
dtype_ = inputs[kIndex0]->GetDtype();
auto op_prim = std::dynamic_pointer_cast<ops::MatrixPower>(base_operator);
MS_ERROR_IF_NULL(op_prim);
power_ = op_prim->get_exponent();
return true;
}

View File

@ -17,11 +17,13 @@
#include "plugin/device/cpu/kernel/reduce_scatter_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/hal/device/mpi/mpi_interface.h"
#include "mindspore/core/ops/reduce_scatter.h"
#include "ir/primitive.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr auto kOP = "op";
constexpr auto kRanksGroup = "group";
constexpr size_t kReduceScatterInputsNum = 1;
constexpr size_t kReduceScatterOutputsNum = 1;
@ -29,22 +31,23 @@ constexpr size_t kReduceScatterOutputsNum = 1;
ReduceScatterCpuKernelMod::ReduceScatterCpuKernelMod() : op_type_(kMPIOpTypeSum) {}
void ReduceScatterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(primitive);
auto op = primitive->GetAttr("op");
bool ReduceScatterCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_ERROR_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto op_prim = std::dynamic_pointer_cast<ops::ReduceScatter>(base_operator);
MS_ERROR_IF_NULL(op_prim);
auto op = op_prim->GetAttr(kOP);
if (op != nullptr) {
op_type_ = GetValue<std::string>(op);
}
auto ranks_group = primitive->GetAttr(kRanksGroup);
if (ranks_group != nullptr) {
ranks_group_ = GetValue<std::vector<int>>(ranks_group);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'group' can not be null, but got empty value.";
auto ranks_group = op_prim->GetAttr(kRanksGroup);
if (ranks_group == nullptr) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'group' can not be null, but got empty value.";
return false;
}
ranks_group_ = GetValue<std::vector<int>>(ranks_group);
return true;
}
bool ReduceScatterCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,

View File

@ -24,12 +24,13 @@
namespace mindspore {
namespace kernel {
class ReduceScatterCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class ReduceScatterCpuKernelMod : public NativeCpuKernelMod {
public:
ReduceScatterCpuKernelMod();
~ReduceScatterCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

View File

@ -17,15 +17,17 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_OTHER_CONCAT_OFFSET_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_OTHER_CONCAT_OFFSET_GPU_KERNEL_H_
#include <map>
#include <vector>
#include <string>
#include "mindspore/core/ops/concat_offset.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class ConcatOffsetGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class ConcatOffsetGpuKernelMod : public NativeGpuKernelMod {
public:
ConcatOffsetGpuKernelMod() {}
~ConcatOffsetGpuKernelMod() override = default;
@ -34,23 +36,42 @@ class ConcatOffsetGpuKernelMod : public DeprecatedNativeGpuKernelMod {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
S *output_device_address = GetDeviceAddress<S>(outputs, 0);
size_t out_size = out_offset_.size() * sizeof(S);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_device_address, out_offset_.data(), out_size,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync error in ConcatOffsetGpuKernelMod::Launch");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_device_address, out_offset_.data(), out_size, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync error in ConcatOffsetGpuKernelMod::Launch");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
if (!CheckParam(kernel_node)) {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->GetPrim()->name();
constexpr size_t outputs_num = 1;
if (outputs.size() != outputs_num) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of outputs should be 1, but got " << outputs.size();
return false;
}
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (inputs.size() == 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of input is 0";
return false;
}
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
ResetResource();
if (inputs[kIndex0]->IsDynamicShape()) {
return KRET_UNKNOWN_SHAPE;
}
auto input_shape = inputs[kIndex0]->GetShapeVector();
auto rank = input_shape.size();
auto rank_int = SizeToInt(rank);
auto axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
auto kernel_ptr = std::dynamic_pointer_cast<ops::ConcatOffset>(base_operator);
int64_t axis = 0;
if (kernel_ptr->HasAttr(kAttrAxis)) {
axis = kernel_ptr->get_axis();
}
if (axis < -rank_int || axis >= rank_int) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in the range [-" << rank << "," << rank
<< "), but got " << axis;
@ -58,23 +79,20 @@ class ConcatOffsetGpuKernelMod : public DeprecatedNativeGpuKernelMod {
if (axis < 0) {
axis += rank_int;
}
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of input should be greater than 0";
}
size_t input_num = inputs.size();
for (size_t i = 0; i < input_num; i++) {
int64_t input_size = 1;
auto input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
input_size *= input_shape[j];
auto input_shape_i = inputs[i]->GetDeviceShapeAdaptively();
for (size_t j = 0; j < input_shape_i.size(); j++) {
input_size *= input_shape_i[j];
}
input_size_list_.push_back(LongToSizeClipNeg(input_size) * sizeof(T));
}
// cal offset
int64_t shape_offset = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0)[axis];
int64_t shape_offset = input_shape[axis];
std::vector<size_t> offset(input_num, 0);
for (size_t i = 1; i < input_num; i++) {
input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
input_shape = inputs[i]->GetShapeVector();
if (input_shape.size() != rank) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the dimension of input should be equal, but got:"
<< " the dimension of the " << i << "'th input: " << input_shape.size()
@ -84,18 +102,18 @@ class ConcatOffsetGpuKernelMod : public DeprecatedNativeGpuKernelMod {
shape_offset += input_shape[axis];
}
constexpr size_t kConcatOffsetOutputShapeSize = 2;
auto output_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, 0));
auto output_shape = outputs[0]->GetShapeVector();
if (output_shape.size() != kConcatOffsetOutputShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output should be "
<< kConcatOffsetOutputShapeSize << ", but got:" << output_shape.size();
}
if (output_shape[0] != input_num) {
if (output_shape[0] != SizeToInt(input_num)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the first dimension value of output should be equal to "
"the number of input, but got the first dimension value of output: "
<< output_shape[0] << ", and the number of input: " << input_num;
}
if (output_shape[1] != rank) {
if (output_shape[1] != rank_int) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the second dimension value of output should be equal to "
"the dimension of input, but got the second dimension value of output: "
@ -107,27 +125,16 @@ class ConcatOffsetGpuKernelMod : public DeprecatedNativeGpuKernelMod {
out_offset_[i * rank + axis] = offset[i];
}
output_size_list_.push_back(out_offset_.size() * sizeof(S));
InitSizeLists();
return true;
return KRET_OK;
}
void ResetResource() noexcept override {
ResetSizeLists();
void ResetResource() {
input_size_list_.clear();
output_size_list_.clear();
out_offset_.clear();
}
protected:
void InitSizeLists() override {}
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs should be 1, but got " << output_num;
}
return true;
}
std::vector<S> out_offset_;
};
} // namespace kernel

View File

@ -0,0 +1,33 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/concat_offset.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
int64_t ConcatOffset::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
void ConcatOffset::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
MIND_API_OPERATOR_IMPL(ConcatOffset, BaseOperator);
REGISTER_PRIMITIVE_C(kNameConcatOffset, ConcatOffset);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CONCAT_OFFSET_H_
#define MINDSPORE_CORE_OPS_CONCAT_OFFSET_H_
#include <vector>
#include <memory>
#include "ops/op_name.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConcatOffset = "ConcatOffset";
/// \brief Computes offsets of concat inputs within its output.
/// Refer to Python API @ref mindspore.ops.ConcatOffset for more details.
class MIND_API ConcatOffset : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ConcatOffset);
/// \brief Constructor.
ConcatOffset() : BaseOperator(kNameConcatOffset) { InitIOName({"N", "axis"}, {"y"}); }
/// \brief Get axis.
/// \return axis.
int64_t get_axis() const;
/// \brief Set axis.
void set_axis(const int64_t axis);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CONCAT_OFFSET_H_

View File

@ -13,15 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/reduce_scatter.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "mindspore/ccsrc/include/common/utils/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(ReduceScatter, BaseOperator);
void ReduceScatter::set_group(const string &group) {
std::string g = group;
(void)this->AddAttr(kGroup, api::MakeValue(g));
@ -49,6 +50,51 @@ int ReduceScatter::get_rank_size() const {
return static_cast<int>(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameReduceScatter, ReduceScatter);
class ReduceScatterInfer : public abstract::OpInferBase {
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_ERROR_IF_NULL_W_RET_VAL(primitive, std::make_shared<abstract::Shape>());
auto value_ptr = primitive->GetAttr(kRankSize);
MS_ERROR_IF_NULL_W_RET_VAL(value_ptr, std::make_shared<abstract::Shape>());
auto rank_size = static_cast<int>(GetValue<int64_t>(value_ptr));
if (rank_size == 0) {
MS_LOG(ERROR) << "For '" << primitive->name() << "', the 'rank_size' can not be zero, but got " << rank_size;
return std::make_shared<abstract::Shape>();
}
auto abstract_shape = input_args[kIndex0]->BuildShape();
MS_ERROR_IF_NULL_W_RET_VAL(abstract_shape, std::make_shared<abstract::Shape>());
if (abstract_shape->IsDynamic()) {
return abstract_shape;
}
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abstract_shape)[kShape];
if (shape.empty() || shape[0] % rank_size != 0) {
MS_LOG(ERROR) << "the first dimension for 'input_shape' must be divided by 'rank_size', but got input_shape[0]: "
<< shape[0] << ", rank_size: " << rank_size;
return std::make_shared<abstract::Shape>();
}
auto out_shape = shape;
out_shape[0] = static_cast<int64_t>(shape[0] / rank_size);
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
auto dtype = input_args[kIndex0]->BuildType();
const std::set<TypePtr> default_valid_types = {kInt8, kInt32, kFloat16, kFloat32};
const std::set<TypePtr> gpu_valid_types = {kBool, kInt8, kInt32, kUInt32, kInt64,
kUInt64, kFloat16, kFloat32, kFloat64};
const std::string input_name = "input";
auto context_ptr = MsContext::GetInstance();
auto is_gpu = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
if (is_gpu) {
(void)CheckAndConvertUtils::CheckTensorTypeValid(input_name, dtype, gpu_valid_types, primitive->name());
} else {
(void)CheckAndConvertUtils::CheckTensorTypeValid(input_name, dtype, default_valid_types, primitive->name());
}
return dtype;
}
};
MIND_API_OPERATOR_IMPL(ReduceScatter, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceScatter, prim::kPrimReduceScatter, ReduceScatterInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -130,7 +130,7 @@ class _ScatterOpDynamic(PrimitiveWithCheck):
raise ValueError(f"For '{prim_name}', the 'input_x' does not support dynamic shape, "
f"but got the shape of 'input_x' is {x_shape}.")
# support indices and updates dynamic
if np.any(np.array(indices_shape) == -1) or np.any(np.array(updates_shape) == -1):
if is_shape_unknown(indices_shape) or is_shape_unknown(updates_shape):
pass
elif indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', "

View File

@ -389,7 +389,7 @@ class _HostAllGather(PrimitiveWithInfer):
raise NotImplementedError
class ReduceScatter(PrimitiveWithInfer):
class ReduceScatter(Primitive):
"""
Reduces and scatters tensors from the specified communication group.
@ -454,19 +454,6 @@ class ReduceScatter(PrimitiveWithInfer):
self.add_prim_attr('fusion', 0)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
if self.rank_size == 0:
raise ValueError(f"For '{self.name}', the 'rank_size' can not be zero, but got {self.rank_size}.")
if x_shape[0] % self.rank_size != 0:
raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' must be divided by 'rank_size', "
f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.rank_size}.")
x_shape[0] = int(x_shape[0] / self.rank_size)
return x_shape
def infer_dtype(self, x_dtype):
check_collective_target_dtype('x', x_dtype, self.name)
return x_dtype
def __call__(self, tensor):
raise NotImplementedError

View File

@ -0,0 +1,65 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import sys
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
class ConcatOffsetNet(nn.Cell):
def __init__(self):
super().__init__()
self.unique = P.Unique()
self.concat_offset = G.ConcatOffset(3, 0)
self.reshape = P.Reshape()
def construct(self, x, y, z):
x = self.reshape(self.unique(x)[0], (-1, 1, 2, 1))
y = self.reshape(self.unique(y)[0], (-1, 1, 2, 1))
z = self.reshape(self.unique(z)[0], (-1, 1, 2, 1))
out = self.concat_offset((x, y, z))
return out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_offset_dynamic_gpu():
"""
/// Feature: Concatoffset op dynamic shape
/// Description: Concatoffset forward with dynamic shape
/// Expectation: Euqal to expected value
"""
if sys.platform != 'linux':
return
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32)
x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32)
x3 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32)
net = ConcatOffsetNet()
out = net(x, x2, x3)
expect = np.array([[0, 0, 0, 0],
[3, 0, 0, 0],
[6, 0, 0, 0]])
if isinstance(out, tuple):
assert (np.array(out) == expect).all()
else:
assert (out.asnumpy() == expect).all()