WeightQuantBatchMatMulV2 Aclnn adapter
commit 75b18df4e1
parent 3112d7791d
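Adds the Ascend aclnn adapter for WeightQuantBatchMatmulV2: an AclnnKernelMod wrapper for the graph path, a pyboost customize implementation that materializes the transpose flags with explicit Transpose ops, dispatch wiring in the op YAML, a GetScalarValue-based guard in the shape inference, and a parametrized (GE/KBK/pynative) antiquant test.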
@@ -0,0 +1,57 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/ascend/kernel/opapi/aclnn/weight_quant_batch_matmul_aclnn_kernel.h"
#include <algorithm>
#include <vector>
#include <memory>
#include <functional>
#include "ir/tensor.h"
#include "runtime/device/kernel_runtime.h"
#include "transform/acl_ir/op_api_convert.h"
#include "abstract/ops/primitive_infer_map.h"

namespace mindspore {
namespace kernel {

// Input order: 0 x, 1 weight, 2 antiquant_scale, 3 antiquant_offset, 4 quant_scale,
// 5 quant_offset, 6 bias, 7 transpose_x, 8 transpose_weight, 9 antiquant_group_size.
void WeightQuantBatchMatmulV2Ascend::GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
                                                      const std::vector<KernelTensor *> &outputs) {
  auto trans_x = transform::ConvertKernelTensor<bool>(inputs[kIndex7]);
  auto trans_weight = transform::ConvertKernelTensor<bool>(inputs[kIndex8]);
  auto antiquant_group_size = transform::ConvertKernelTensor<int64_t>(inputs[kIndex9]);

  // x and weight travel as (tensor, transpose flag) pairs so the acl converter
  // can apply the transposes when building the matmul call.
  input_x_ = std::pair<KernelTensor *, bool>(inputs[kIndex0], trans_x);
  input_weight_ = std::pair<KernelTensor *, bool>(inputs[kIndex1], trans_weight);
  GetWorkspaceForResize(input_x_, input_weight_, inputs[kIndex2], inputs[kIndex3], inputs[kIndex4], inputs[kIndex5],
                        inputs[kIndex6], antiquant_group_size, outputs[kIndex0]);
}

bool WeightQuantBatchMatmulV2Ascend::Launch(const std::vector<KernelTensor *> &inputs,
                                            const std::vector<KernelTensor *> &workspace,
                                            const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
  MS_EXCEPTION_IF_NULL(stream_ptr);

  auto antiquant_group_size = transform::ConvertKernelTensor<int64_t>(inputs[kIndex9]);
  // Only the device addresses can change between steps; the transpose flags
  // cached by GetWorkSpaceInfo stay valid.
  input_x_.first = inputs[kIndex0];
  input_weight_.first = inputs[kIndex1];
  ParseGenExecutor(GEN_EXECUTOR_BOOST(op_type_, hash_id_, input_x_, input_weight_, inputs[kIndex2], inputs[kIndex3],
                                      inputs[kIndex4], inputs[kIndex5], inputs[kIndex6], antiquant_group_size,
                                      outputs[kIndex0]));
  RunOp(stream_ptr, workspace);
  return true;
}
MS_ACLNN_KERNEL_FACTORY_REG(WeightQuantBatchMatmul, WeightQuantBatchMatmulV2Ascend);
}  // namespace kernel
}  // namespace mindspore
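The adapter above follows the usual two-phase aclnn contract: GetWorkSpaceInfo runs at resize time to cache the transpose flags and query the workspace size, while Launch only rebinds device addresses and executes. A minimal, self-contained sketch of that split (illustrative names only, not the real KernelMod interface):

#include <cstddef>
#include <iostream>

// Illustrative stand-in for the resize/launch split of an aclnn kernel mod.
class TwoPhaseKernel {
 public:
  // Resize phase: shape-dependent setup, done once per shape change.
  void GetWorkSpaceInfo(bool trans_x, bool trans_weight) {
    trans_x_ = trans_x;
    trans_weight_ = trans_weight;
    workspace_bytes_ = 1024;  // placeholder for the real workspace query
  }
  // Launch phase: runs every step; only tensor addresses change.
  void Launch(const void *x_addr, const void *weight_addr) const {
    std::cout << "launch: trans_x=" << trans_x_ << " trans_weight=" << trans_weight_
              << " workspace=" << workspace_bytes_ << "B x=" << x_addr << " w=" << weight_addr << '\n';
  }

 private:
  bool trans_x_ = false;
  bool trans_weight_ = false;
  std::size_t workspace_bytes_ = 0;
};

int main() {
  TwoPhaseKernel kernel;
  kernel.GetWorkSpaceInfo(false, true);  // like reading inputs[kIndex7] / inputs[kIndex8]
  int x = 0, w = 0;
  kernel.Launch(&x, &w);  // like rebinding input_x_.first / input_weight_.first
  return 0;
}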
@@ -0,0 +1,44 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_WEIGHT_QUANT_BATCH_MATMUL_ACLNN_KERNEL_MOD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_WEIGHT_QUANT_BATCH_MATMUL_ACLNN_KERNEL_MOD_H_
#include <vector>
#include <utility>
#include "ops/base_operator.h"
#include "plugin/device/ascend/kernel/opapi/aclnn_kernel_mod.h"
#include "transform/acl_ir/acl_convert.h"

namespace mindspore {
namespace kernel {

class WeightQuantBatchMatmulV2Ascend : public AclnnKernelMod {
 public:
  WeightQuantBatchMatmulV2Ascend() : AclnnKernelMod(std::move("aclnnWeightQuantBatchMatmulV2")) {}
  ~WeightQuantBatchMatmulV2Ascend() = default;
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override;
  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;

 private:
  DEFINE_GET_WORKSPACE_FOR_RESIZE()

  // x and weight bundled with their transpose flags for the acl converter.
  std::pair<KernelTensor *, bool> input_x_;
  std::pair<KernelTensor *, bool> input_weight_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_WEIGHT_QUANT_BATCH_MATMUL_ACLNN_KERNEL_MOD_H_
@@ -0,0 +1,110 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/kernel/pyboost/customize/weight_quant_batch_matmul.h"
#include <memory>
#include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
#include "plugin/device/ascend/kernel/pyboost/auto_generate/transpose.h"
#include "kernel/pyboost/op_register.h"
#include "kernel/pyboost/pyboost_utils.h"
#include "plugin/device/ascend/kernel/pyboost/aclnn_utils.h"

namespace mindspore {
namespace kernel {
namespace pyboost {
namespace {
void WeightQuantBatchMatmulV2AscendCall(const std::shared_ptr<OpRunner> &op,
                                        const device::DeviceContext *device_context, const TensorPtr &x_tensor,
                                        const TensorPtr &weight_tensor, const TensorPtr &antiquant_scale_tensor,
                                        const std::optional<TensorPtr> &antiquant_offset_tensor,
                                        const std::optional<TensorPtr> &quant_scale_tensor,
                                        const std::optional<TensorPtr> &quant_offset_tensor,
                                        const std::optional<TensorPtr> &bias_tensor, int64_t antiquant_group_size,
                                        const std::vector<tensor::TensorPtr> &outputs) {
  MS_LOG(DEBUG) << "Call start";
  LAUNCH_ACLNN(aclnnWeightQuantBatchMatmulV2, device_context, op->stream_id(), x_tensor, weight_tensor,
               antiquant_scale_tensor, antiquant_offset_tensor, quant_scale_tensor, quant_offset_tensor, bias_tensor,
               antiquant_group_size, outputs[0]);
  MS_LOG(DEBUG) << "Call end";
}
ValueTuplePtr GetTransposePerm(const TensorPtr &weight_tensor) {
  const auto &shape = weight_tensor->shape();
  auto size = static_cast<int64_t>(shape.size());
  std::vector<ValuePtr> perm(size);
  if (size < 2) {
    // Rank-0 and rank-1 tensors need no axis swap: return the identity perm
    // (and avoid indexing into an empty vector for rank 0).
    if (size == 1) {
      perm[0] = MakeValue(static_cast<int64_t>(0));
    }
    return std::make_shared<ValueTuple>(perm);
  }
  // Swap the last two axes; leading (batch) axes keep their positions.
  perm[size - 1] = MakeValue(size - 2);
  perm[size - 2] = MakeValue(size - 1);
  for (int64_t i = 0; i < size - 2; ++i) {
    perm[i] = MakeValue(i);
  }
  return std::make_shared<ValueTuple>(perm);
}
}  // namespace
tensor::TensorPtr WeightQuantBatchMatmulV2AscendCustomize(
  const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor, const TensorPtr &weight_tensor,
  const TensorPtr &antiquant_scale_tensor, const std::optional<TensorPtr> &antiquant_offset_tensor,
  const std::optional<TensorPtr> &quant_scale_tensor, const std::optional<TensorPtr> &quant_offset_tensor,
  const std::optional<TensorPtr> &bias_tensor, const BoolImmPtr &transpose_x, const BoolImmPtr &transpose_weight,
  const Int64ImmPtr &antiquant_group_size) {
  OpRunner::InferOpOutput(op, x_tensor, weight_tensor, antiquant_scale_tensor, antiquant_offset_tensor,
                          quant_scale_tensor, quant_offset_tensor, bias_tensor, transpose_x, transpose_weight,
                          antiquant_group_size);
  auto transpose_x_imm = GetValue<bool>(transpose_x);
  auto transpose_weight_imm = GetValue<bool>(transpose_weight);
  auto antiquant_group_size_imm = GetValue<int64_t>(antiquant_group_size);

  PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), x_tensor, weight_tensor, antiquant_scale_tensor,
                                antiquant_offset_tensor, quant_scale_tensor, quant_offset_tensor, bias_tensor);
  PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());

  // The aclnn call above takes no transpose arguments, so transpose_x and
  // transpose_weight are materialized here with explicit Transpose ops.
  auto device_context = op->device_context();
  TensorPtr x_tensor_trans = x_tensor;
  if (transpose_x_imm) {
    const auto &device_name = device_context->device_context_key_.device_name_;
    auto transpose_op = CREATE_PYBOOST_OP(Transpose, device_name);
    x_tensor_trans = transpose_op->Call(x_tensor_trans, GetTransposePerm(x_tensor_trans));
  }
  TensorPtr weight_tensor_trans = weight_tensor;
  if (transpose_weight_imm) {
    const auto &device_name = device_context->device_context_key_.device_name_;
    auto transpose_op = CREATE_PYBOOST_OP(Transpose, device_name);
    weight_tensor_trans = transpose_op->Call(weight_tensor_trans, GetTransposePerm(weight_tensor_trans));
  }
  PyBoostUtils::DispatchRun(std::make_shared<runtime::PyBoostDeviceTask>(
    [op, x_tensor_trans, weight_tensor_trans, antiquant_scale_tensor, antiquant_offset_tensor, quant_scale_tensor,
     quant_offset_tensor, bias_tensor, transpose_x_imm, transpose_weight_imm, antiquant_group_size_imm]() {
      MS_LOG(DEBUG) << "Run device task WeightQuantBatchMatmulV2 start";
      auto device_context = op->device_context();
      const auto &outputs = op->outputs();
      // Malloc for input tensors
      PyBoostUtils::MallocOpInputs(device_context, x_tensor_trans, weight_tensor_trans, antiquant_scale_tensor,
                                   antiquant_offset_tensor, quant_scale_tensor, quant_offset_tensor, bias_tensor);
      // Malloc for output tensors
      PyBoostUtils::MallocOpOutputs(device_context, outputs);
      WeightQuantBatchMatmulV2AscendCall(op, device_context, x_tensor_trans, weight_tensor_trans,
                                         antiquant_scale_tensor, antiquant_offset_tensor, quant_scale_tensor,
                                         quant_offset_tensor, bias_tensor, antiquant_group_size_imm, outputs);
      MS_LOG(DEBUG) << "Run device task WeightQuantBatchMatmulV2 end";
    }));
  return op->output(0);
}
}  // namespace pyboost
}  // namespace kernel
}  // namespace mindspore
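The GetTransposePerm helper above builds the permutation that Transpose consumes: it swaps the last two axes and keeps any leading batch axes in place. A standalone sketch of the same logic with plain int64_t values (hypothetical TransposePerm name, no MindSpore types):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Builds the perm used to transpose the last two axes of a rank-n tensor,
// e.g. rank 4 -> {0, 1, 3, 2}; ranks 0 and 1 map to the identity permutation.
std::vector<int64_t> TransposePerm(int64_t rank) {
  std::vector<int64_t> perm(rank);
  for (int64_t i = 0; i < rank; ++i) perm[i] = i;
  if (rank >= 2) std::swap(perm[rank - 1], perm[rank - 2]);
  return perm;
}

int main() {
  for (int64_t rank : {1, 2, 4}) {
    std::cout << "rank " << rank << ":";
    for (auto axis : TransposePerm(rank)) std::cout << ' ' << axis;
    std::cout << '\n';
  }
  return 0;
}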
@@ -0,0 +1,39 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_PYBOOST_CUSTOMIZE_WEIGHT_QUANT_V2_H_
#define MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_PYBOOST_CUSTOMIZE_WEIGHT_QUANT_V2_H_

#include <vector>
#include <memory>
#include "ir/tensor.h"
#include "ir/value.h"
#include "runtime/hardware/device_context_manager.h"
#include "kernel/pyboost/op_runner.h"

namespace mindspore {
namespace kernel {
namespace pyboost {
tensor::TensorPtr WeightQuantBatchMatmulV2AscendCustomize(
  const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor, const TensorPtr &weight_tensor,
  const TensorPtr &antiquant_scale_tensor, const std::optional<TensorPtr> &antiquant_offset_tensor,
  const std::optional<TensorPtr> &quant_scale_tensor, const std::optional<TensorPtr> &quant_offset_tensor,
  const std::optional<TensorPtr> &bias_tensor, const BoolImmPtr &transpose_x, const BoolImmPtr &transpose_weight,
  const Int64ImmPtr &antiquant_group_size);
}  // namespace pyboost
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_PYBOOST_CUSTOMIZE_WEIGHT_QUANT_V2_H_
@@ -34,3 +34,6 @@ weight_quant_batch_matmul:
   returns:
     y:
       dtype: tensor
+  dispatch:
+    enable: True
+    Ascend: WeightQuantBatchMatmulV2Ascend
@@ -53,8 +53,18 @@ BaseShapePtr WeightQuantBatchMatmulFuncImpl::InferShape(const PrimitivePtr &prim
 
   ValuePtr transpose_x_ptr = input_args[kInputTransposeX]->GetValue();
   ValuePtr transpose_weight_ptr = input_args[kInputTransposeWeight]->GetValue();
-  bool transpose_x = GetValue<bool>(transpose_x_ptr);
-  bool transpose_weight = GetValue<bool>(transpose_weight_ptr);
+  auto transpose_x_any = GetScalarValue<bool>(transpose_x_ptr);
+  if (!transpose_x_any.has_value()) {
+    MS_LOG(EXCEPTION) << "For '" << prim_name
+                      << "', input 'transpose_x' has no value:" << input_args[kInputTransposeX]->ToString();
+  }
+  bool transpose_x = transpose_x_any.value();
+  auto transpose_weight_any = GetScalarValue<bool>(transpose_weight_ptr);
+  if (!transpose_weight_any.has_value()) {
+    MS_LOG(EXCEPTION) << "For '" << prim_name
+                      << "', input 'transpose_weight' has no value:" << input_args[kInputTransposeWeight]->ToString();
+  }
+  bool transpose_weight = transpose_weight_any.value();
 
   auto x_shp = x_shape_map[kShape];
   auto weight_shp = weight_shape_map[kShape];
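The switch from GetValue<bool> to GetScalarValue<bool> matters because during shape inference the transpose flags may not be concrete yet; GetScalarValue returns an optional-style wrapper that must be checked before use rather than crashing on an unknown value. A minimal sketch of the same pattern using std::optional as a stand-in (names are illustrative):

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

// Stand-in for GetScalarValue<bool>: at infer time the flag may be unknown.
std::optional<bool> GetScalarFlag(bool known, bool value) {
  if (!known) {
    return std::nullopt;
  }
  return value;
}

// Mirrors the MS_LOG(EXCEPTION) guard: fail loudly, naming the missing input.
bool RequireFlag(const std::optional<bool> &flag, const std::string &name) {
  if (!flag.has_value()) {
    throw std::runtime_error("input '" + name + "' has no value");
  }
  return flag.value();
}

int main() {
  try {
    bool transpose_x = RequireFlag(GetScalarFlag(true, false), "transpose_x");
    std::cout << "transpose_x = " << std::boolalpha << transpose_x << '\n';
    RequireFlag(GetScalarFlag(false, false), "transpose_weight");  // throws
  } catch (const std::exception &e) {
    std::cout << "infer error: " << e.what() << '\n';
  }
  return 0;
}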
@@ -67,6 +67,7 @@ SoftplusExt: 'aclnnSoftplus'
 SoftplusGradExt: 'aclnnSoftplusBackward'
 StackExt: 'aclnnStack'
 TopkExt: 'aclnnTopk'
+WeightQuantBatchMatmul: 'aclnnWeightQuantBatchMatmulV2'
 # 1
 
 # 2
@@ -177,20 +177,27 @@ class DequantBMMCell(Cell):
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.env_onecard
-def test_weight_quant_bmm_cell_as_antiquant_1p():
+@pytest.mark.parametrize('mode', ['GE', 'KBK', 'pynative'])
+def test_weight_quant_bmm_cell_as_antiquant_1p(mode):
     """
     Feature: weight quant bmm cell for antiquant
     Description: test antiquant using weight quant bmm cell
     Expectation: accuracy in tolerance
     """
 
-    context.set_context(device_target="Ascend", mode=GRAPH_MODE)
-    weight = np.array([[100, 200], [10, 25]]).astype(np.int8)
-    activation = np.array([[0.1, 1.], [0.5, 2.4]]).astype(np.float16)
+    weight = np.array([[100, 200, 100], [10, 25, 10]]).astype(np.int8)
+    activation = np.array([[0.1, 1., 0.1], [0.5, 2.4, 0.5]]).astype(np.float16)
     scale = np.array([0.5, 0.27]).astype(np.float16)
     offset = np.array([-127, -10]).astype(np.float16)
-    expect = np.matmul(activation, NumpyQuantOps.anti_quant(weight, scale, offset))
-    wqmm_cell = AntiquantBMMCell(scale, offset)
+    expect = np.matmul(activation, NumpyQuantOps.anti_quant(np.transpose(weight), scale, offset))
+    wqmm_cell = AntiquantBMMCell(scale, offset, dtype.float16, False, True)
+    if mode == 'KBK':
+        context.set_context(device_target="Ascend", mode=GRAPH_MODE)
+        wqmm_cell.set_jit_config(JitConfig(jit_level='O0'))
+    elif mode == 'GE':
+        context.set_context(device_target="Ascend", mode=GRAPH_MODE)
+    else:
+        context.set_context(device_target="Ascend", mode=PYNATIVE_MODE)
     t_activation = Tensor(activation, dtype=dtype.float16)
     p_weight = Parameter(Tensor(weight, dtype=dtype.int8), 'weight')
     fact = wqmm_cell(t_activation, p_weight).asnumpy()
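The expected value now anti-quantizes the transposed weight because the cell is constructed with transpose_weight=True. A standalone numeric sketch of that bookkeeping, assuming anti_quant dequantizes as scale * (w + offset) per output channel (an assumption; the actual formula lives in the test suite's NumpyQuantOps) and noting that np.int8 wraps 200 to -56:

#include <array>
#include <iostream>

// Hedged reference for the antiquant test: this sketch only illustrates the
// shape/transpose bookkeeping, with the anti-quant formula assumed as s * (w + o).
int main() {
  // np.int8 wraps 200 to -56, so the wrapped value is used directly.
  const double weight_t[3][2] = {{100, 10}, {-56, 25}, {100, 10}};  // transpose of the 2x3 weight
  const double activation[2][3] = {{0.1, 1.0, 0.1}, {0.5, 2.4, 0.5}};
  const double scale[2] = {0.5, 0.27};
  const double offset[2] = {-127, -10};

  double dequant[3][2];
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 2; ++j) {
      dequant[i][j] = scale[j] * (weight_t[i][j] + offset[j]);  // assumed anti-quant formula
    }
  }
  // expect = activation (2x3) @ dequant (3x2) -> 2x2
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      double acc = 0;
      for (int k = 0; k < 3; ++k) acc += activation[i][k] * dequant[k][j];
      std::cout << acc << (j == 0 ? ' ' : '\n');
    }
  }
  return 0;
}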