!12149 quantum simulation

From: @donghufeng
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-22 14:21:13 +08:00 committed by Gitee
commit 2088d38610
39 changed files with 2237 additions and 59 deletions

View File

@ -0,0 +1,25 @@
set(projectq_CXXFLAGS "-fopenmp -O2 -ffast-mast -march=native -DINTRIN")
set(projectq_CFLAGS "-fopenmp -O2 -ffast-mast -march=native -DINTRIN")
if(ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/ProjectQ/repository/archive/v0.5.1.tar.gz")
set(MD5 "d874e93e56d3375f1c54c7dd1b731054")
else()
set(REQ_URL "https://github.com/ProjectQ-Framework/ProjectQ/archive/v0.5.1.tar.gz ")
set(MD5 "13430199c253284df8b3d840f11d3560")
endif()
if(ENABLE_CPU AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux"
AND ${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message("Include projectq simulator")
mindspore_add_pkg(projectq
VER 0.5.1
HEAD_ONLY ./
URL ${REQ_URL}
MD5 ${MD5}
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/projectq/projectq.patch001
)
include_directories(${projectq_INC})
else()
message("Quantum simulation only support x86_64 linux platform.")
endif()

View File

@ -48,6 +48,12 @@ if(ENABLE_CPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake)
endif()
if(ENABLE_CPU AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux"
AND ${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message("Include projectq")
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/projectq.cmake)
endif()
if(ENABLE_GPU)
if(ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)

View File

@ -192,7 +192,7 @@ set(SUB_COMP
frontend/operator
pipeline/jit
pipeline/pynative
common debug pybind_api utils vm profiler ps
common debug pybind_api utils vm profiler ps mindquantum
)
foreach(_comp ${SUB_COMP})
@ -333,13 +333,14 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore mindspore::pybind11_module)
target_link_libraries(mindspore mindspore_gvar)
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore mindspore::pybind11_module)
target_link_libraries(mindspore mindspore_gvar)
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
else ()
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
else()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf
mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
if(${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(mindspore ibverbs rdmacm)

View File

@ -37,6 +37,26 @@ if(ENABLE_CPU)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc")
endif()
file(GLOB_RECURSE QUANTUM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"cpu/quantum/*.cc"
)
foreach(qcc ${QUANTUM_SRC_LIST})
list(REMOVE_ITEM CPU_SRC_LIST ${qcc})
endforeach()
if(ENABLE_CPU AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux"
AND ${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message("compiled quantum kernel_compiler")
set_property(SOURCE ${QUANTUM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_MINDQUANTUM)
set_property(SOURCE ${QUANTUM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS INTRIN)
set_property(SOURCE ${QUANTUM_SRC_LIST} PROPERTY COMPILE_OPTIONS -fopenmp -march=native -ffast-math)
else()
message("not compiled quantum kernel_compiler")
set(QUANTUM_SRC_LIST "")
endif()
endif()
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
@ -76,4 +96,5 @@ endif()
set_property(SOURCE ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_KERNEL)
add_library(_mindspore_backend_kernel_compiler_obj OBJECT ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST})
add_library(_mindspore_backend_kernel_compiler_obj OBJECT ${KERNEL_SRC_LIST} ${CPU_SRC_LIST}
${GPU_SRC_LIST} ${D_SRC_LIST} ${QUANTUM_SRC_LIST})

View File

@ -0,0 +1,248 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/quantum/pqc_cpu_kernel.h"
#include <omp.h>
#include <utility>
#include <thread>
#include <memory>
#include <algorithm>
#include "utils/ms_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
struct ComputeParam {
float *encoder_data_cp{nullptr};
float *ansatz_data_cp{nullptr};
float *output_cp{nullptr};
float *gradient_encoder_cp{nullptr};
float *gradient_ansatz_cp{nullptr};
mindquantum::BasicCircuit *circ_cp;
mindquantum::BasicCircuit *herm_circ_cp;
mindquantum::transformer::Hamiltonians *hams_cp;
mindquantum::transformer::NamesType *encoder_params_names_cp;
mindquantum::transformer::NamesType *ansatz_params_names_cp;
std::vector<std::vector<std::shared_ptr<mindquantum::PQCSimulator>>> *tmp_sims_cp;
bool dummy_circuit_cp{false};
size_t result_len_cp{0};
size_t encoder_g_len_cp{0};
size_t ansatz_g_len_cp{0};
};
void ComputerForwardBackward(const std::shared_ptr<ComputeParam> &input_params, size_t start, size_t end, size_t id) {
MS_EXCEPTION_IF_NULL(input_params);
MS_EXCEPTION_IF_NULL(input_params->encoder_data_cp);
MS_EXCEPTION_IF_NULL(input_params->ansatz_data_cp);
MS_EXCEPTION_IF_NULL(input_params->output_cp);
MS_EXCEPTION_IF_NULL(input_params->gradient_encoder_cp);
MS_EXCEPTION_IF_NULL(input_params->gradient_ansatz_cp);
auto encoder_data = input_params->encoder_data_cp;
auto ansatz_data = input_params->ansatz_data_cp;
auto output = input_params->output_cp;
auto gradient_encoder = input_params->gradient_encoder_cp;
auto gradient_ansatz = input_params->gradient_ansatz_cp;
auto circ = input_params->circ_cp;
auto herm_circ = input_params->herm_circ_cp;
auto hams = input_params->hams_cp;
auto encoder_params_names = input_params->encoder_params_names_cp;
auto ansatz_params_names = input_params->ansatz_params_names_cp;
auto tmp_sims = input_params->tmp_sims_cp;
auto dummy_circuit = input_params->dummy_circuit_cp;
auto result_len = input_params->result_len_cp;
auto encoder_g_len = input_params->encoder_g_len_cp;
auto ansatz_g_len = input_params->ansatz_g_len_cp;
MS_EXCEPTION_IF_NULL(hams);
MS_EXCEPTION_IF_NULL(encoder_params_names);
MS_EXCEPTION_IF_NULL(ansatz_params_names);
MS_EXCEPTION_IF_NULL(tmp_sims);
if (end * hams->size() > result_len || end * encoder_params_names->size() * hams->size() > encoder_g_len ||
end * ansatz_params_names->size() * hams->size() > ansatz_g_len) {
MS_LOG(EXCEPTION) << "pqc error input size!";
}
mindquantum::ParameterResolver pr;
for (size_t i = 0; i < ansatz_params_names->size(); i++) {
pr.SetData(ansatz_params_names->at(i), ansatz_data[i]);
}
for (size_t n = start; n < end; ++n) {
for (size_t i = 0; i < encoder_params_names->size(); i++) {
pr.SetData(encoder_params_names->at(i), encoder_data[n * encoder_params_names->size() + i]);
}
auto sim = tmp_sims->at(id)[3];
sim->SetZeroState();
sim->Evolution(*circ, pr);
auto calc_gradient_param = std::make_shared<mindquantum::CalcGradientParam>();
calc_gradient_param->circuit_cp = circ;
calc_gradient_param->circuit_hermitian_cp = herm_circ;
calc_gradient_param->hamiltonians_cp = hams;
calc_gradient_param->paras_cp = &pr;
calc_gradient_param->encoder_params_names_cp = encoder_params_names;
calc_gradient_param->ansatz_params_names_cp = ansatz_params_names;
calc_gradient_param->dummy_circuit_cp = dummy_circuit;
auto e0_grad1_grad_2 =
sim->CalcGradient(calc_gradient_param, *tmp_sims->at(id)[0], *tmp_sims->at(id)[1], *tmp_sims->at(id)[2]);
auto energy = e0_grad1_grad_2[0];
auto grad_encoder = e0_grad1_grad_2[1];
auto grad_ansatz = e0_grad1_grad_2[2];
if (energy.size() != hams->size() || grad_encoder.size() != encoder_params_names->size() * hams->size() ||
grad_ansatz.size() != ansatz_params_names->size() * hams->size()) {
MS_LOG(EXCEPTION) << "pqc error evolution or batch size!";
}
for (size_t poi = 0; poi < hams->size(); poi++) {
output[n * hams->size() + poi] = energy[poi];
}
for (size_t poi = 0; poi < encoder_params_names->size() * hams->size(); poi++) {
gradient_encoder[n * hams->size() * encoder_params_names->size() + poi] = grad_encoder[poi];
}
for (size_t poi = 0; poi < ansatz_params_names->size() * hams->size(); poi++) {
gradient_ansatz[n * hams->size() * ansatz_params_names->size() + poi] = grad_ansatz[poi];
}
}
}
} // namespace
void PQCCPUKernel::InitPQCStructure(const CNodePtr &kernel_node) {
n_threads_user_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, mindquantum::kNThreads);
n_qubits_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, mindquantum::kNQubits);
encoder_params_names_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::NamesType>(kernel_node, mindquantum::kEncoderParamsNames);
ansatz_params_names_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::NamesType>(kernel_node, mindquantum::kAnsatzParamsNames);
gate_names_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::NamesType>(kernel_node, mindquantum::kGateNames);
gate_matrix_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::ComplexMatrixsType>(kernel_node, mindquantum::kGateMatrix);
gate_obj_qubits_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::Indexess>(kernel_node, mindquantum::kGateObjQubits);
gate_ctrl_qubits_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::Indexess>(kernel_node, mindquantum::kGateCtrlQubits);
gate_params_names_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::ParasNameType>(kernel_node, mindquantum::kGateParamsNames);
gate_coeff_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::CoeffsType>(kernel_node, mindquantum::kGateCoeff);
gate_requires_grad_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::RequiresType>(kernel_node, mindquantum::kGateRequiresGrad);
hams_pauli_coeff_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::PaulisCoeffsType>(kernel_node, mindquantum::kHamsPauliCoeff);
hams_pauli_word_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::PaulisWordsType>(kernel_node, mindquantum::kHamsPauliWord);
hams_pauli_qubit_ =
AnfAlgo::GetNodeAttr<mindquantum::transformer::PaulisQubitsType>(kernel_node, mindquantum::kHamsPauliQubit);
}
void PQCCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> encoder_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> ansatz_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> result_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
std::vector<size_t> encoder_g_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 1);
std::vector<size_t> ansatz_g_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 2);
if (encoder_shape.size() != 2 || ansatz_shape.size() != 1 || result_shape.size() != 2 ||
encoder_g_shape.size() != 3 || ansatz_g_shape.size() != 3) {
MS_LOG(EXCEPTION) << "pqc invalid input size";
}
result_len_ = result_shape[0] * result_shape[1];
encoder_g_len_ = encoder_g_shape[0] * encoder_g_shape[1] * encoder_g_shape[2];
ansatz_g_len_ = ansatz_g_shape[0] * ansatz_g_shape[1] * ansatz_g_shape[2];
n_samples_ = static_cast<unsigned>(encoder_shape[0]);
InitPQCStructure(kernel_node);
dummy_circuit_ = !std::any_of(gate_requires_grad_.begin(), gate_requires_grad_.end(),
[](const mindquantum::transformer::RequireType &rr) {
return std::any_of(rr.begin(), rr.end(), [](const bool &r) { return r; });
});
auto circs = mindquantum::transformer::CircuitTransfor(gate_names_, gate_matrix_, gate_obj_qubits_, gate_ctrl_qubits_,
gate_params_names_, gate_coeff_, gate_requires_grad_);
circ_ = circs[0];
herm_circ_ = circs[1];
hams_ = mindquantum::transformer::HamiltoniansTransfor(hams_pauli_coeff_, hams_pauli_word_, hams_pauli_qubit_);
n_threads_user_ = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (n_samples_ < n_threads_user_) {
n_threads_user_ = n_samples_;
}
for (size_t i = 0; i < n_threads_user_; i++) {
tmp_sims_.push_back({});
for (size_t j = 0; j < 4; j++) {
auto tmp = std::make_shared<mindquantum::PQCSimulator>(1, n_qubits_);
tmp_sims_.back().push_back(tmp);
}
}
}
bool PQCCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != 2 || outputs.size() != 3) {
MS_LOG(EXCEPTION) << "pqc error input output size!";
}
auto encoder_data = reinterpret_cast<float *>(inputs[0]->addr);
auto ansatz_data = reinterpret_cast<float *>(inputs[1]->addr);
auto output = reinterpret_cast<float *>(outputs[0]->addr);
auto gradient_encoder = reinterpret_cast<float *>(outputs[1]->addr);
auto gradient_ansatz = reinterpret_cast<float *>(outputs[2]->addr);
MS_EXCEPTION_IF_NULL(encoder_data);
MS_EXCEPTION_IF_NULL(ansatz_data);
MS_EXCEPTION_IF_NULL(output);
MS_EXCEPTION_IF_NULL(gradient_encoder);
MS_EXCEPTION_IF_NULL(gradient_ansatz);
std::vector<common::Task> tasks;
std::vector<std::shared_ptr<ComputeParam>> thread_params;
tasks.reserve(n_threads_user_);
size_t end = 0;
size_t offset = n_samples_ / n_threads_user_;
size_t left = n_samples_ % n_threads_user_;
for (size_t i = 0; i < n_threads_user_; ++i) {
auto params = std::make_shared<ComputeParam>();
params->encoder_data_cp = encoder_data;
params->ansatz_data_cp = ansatz_data;
params->output_cp = output;
params->gradient_encoder_cp = gradient_encoder;
params->gradient_ansatz_cp = gradient_ansatz;
params->circ_cp = &circ_;
params->herm_circ_cp = &herm_circ_;
params->hams_cp = &hams_;
params->encoder_params_names_cp = &encoder_params_names_;
params->ansatz_params_names_cp = &ansatz_params_names_;
params->tmp_sims_cp = &tmp_sims_;
params->dummy_circuit_cp = dummy_circuit_;
params->result_len_cp = result_len_;
params->encoder_g_len_cp = encoder_g_len_;
params->ansatz_g_len_cp = ansatz_g_len_;
size_t start = end;
end = start + offset;
if (i < left) {
end += 1;
}
auto task = [&params, start, end, i]() {
ComputerForwardBackward(params, start, end, i);
return common::SUCCESS;
};
tasks.emplace_back(task);
thread_params.emplace_back(params);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,87 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PQC_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PQC_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "mindquantum/pqc_simulator.h"
#include "mindquantum/transformer.h"
#include "mindquantum/circuit.h"
#include "mindquantum/parameter_resolver.h"
namespace mindspore {
namespace kernel {
class PQCCPUKernel : public CPUKernel {
public:
PQCCPUKernel() = default;
~PQCCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
void InitPQCStructure(const CNodePtr &kernel_node);
private:
size_t n_samples_;
size_t n_threads_user_;
bool dummy_circuit_;
size_t result_len_;
size_t encoder_g_len_;
size_t ansatz_g_len_;
int64_t n_qubits_;
mindquantum::BasicCircuit circ_;
mindquantum::BasicCircuit herm_circ_;
mindquantum::transformer::Hamiltonians hams_;
std::vector<std::vector<std::shared_ptr<mindquantum::PQCSimulator>>> tmp_sims_;
// parameters
mindquantum::transformer::NamesType encoder_params_names_;
mindquantum::transformer::NamesType ansatz_params_names_;
// quantum circuit
mindquantum::transformer::NamesType gate_names_;
mindquantum::transformer::ComplexMatrixsType gate_matrix_;
mindquantum::transformer::Indexess gate_obj_qubits_;
mindquantum::transformer::Indexess gate_ctrl_qubits_;
mindquantum::transformer::ParasNameType gate_params_names_;
mindquantum::transformer::CoeffsType gate_coeff_;
mindquantum::transformer::RequiresType gate_requires_grad_;
// hamiltonian
mindquantum::transformer::PaulisCoeffsType hams_pauli_coeff_;
mindquantum::transformer::PaulisWordsType hams_pauli_word_;
mindquantum::transformer::PaulisQubitsType hams_pauli_qubit_;
};
MS_REG_CPU_KERNEL(PQC,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
PQCCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PQC_CPU_KERNEL_H_

View File

@ -26,31 +26,32 @@ namespace mindspore {
namespace {
static const char *GetSubModuleName(SubModuleId module_id) {
static const char *sub_module_names[NUM_SUBMODUES] = {
"UNKNOWN", // SM_UNKNOWN
"CORE", // SM_CORE
"ANALYZER", // SM_ANALYZER
"COMMON", // SM_COMMON
"DEBUG", // SM_DEBUG
"DEVICE", // SM_DEVICE
"GE_ADPT", // SM_GE_ADPT
"IR", // SM_IR
"KERNEL", // SM_KERNEL
"MD", // SM_MD
"ME", // SM_ME
"EXPRESS", // SM_EXPRESS
"OPTIMIZER", // SM_OPTIMIZER
"PARALLEL", // SM_PARALLEL
"PARSER", // SM_PARSER
"PIPELINE", // SM_PIPELINE
"PRE_ACT", // SM_PRE_ACT
"PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION
"UTILS", // SM_UTILS
"VM", // SM_VM
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"HCCL_ADPT" // SM_HCCL_ADPT
"UNKNOWN", // SM_UNKNOWN
"CORE", // SM_CORE
"ANALYZER", // SM_ANALYZER
"COMMON", // SM_COMMON
"DEBUG", // SM_DEBUG
"DEVICE", // SM_DEVICE
"GE_ADPT", // SM_GE_ADPT
"IR", // SM_IR
"KERNEL", // SM_KERNEL
"MD", // SM_MD
"ME", // SM_ME
"EXPRESS", // SM_EXPRESS
"OPTIMIZER", // SM_OPTIMIZER
"PARALLEL", // SM_PARALLEL
"PARSER", // SM_PARSER
"PIPELINE", // SM_PIPELINE
"PRE_ACT", // SM_PRE_ACT
"PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION
"UTILS", // SM_UTILS
"VM", // SM_VM
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"HCCL_ADPT", // SM_HCCL_ADPT
"MINDQUANTUM" // SM_MINDQUANTUM
};
return sub_module_names[module_id % NUM_SUBMODUES];

View File

@ -0,0 +1,10 @@
if(ENABLE_CPU AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux"
AND ${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message("compiled mindquantum")
file(GLOB_RECURSE _MINDQUANTUM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_MINDQUANTUM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_MINDQUANTUM)
add_library(_mindspore_mindquantum_obj OBJECT ${_MINDQUANTUM_SRC_LIST})
target_compile_options(_mindspore_mindquantum_obj PRIVATE -fopenmp -march=native -ffast-math)
target_compile_definitions(_mindspore_mindquantum_obj PRIVATE INTRIN)
endif()

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/circuit.h"
namespace mindspore {
namespace mindquantum {
BasicCircuit::BasicCircuit() : gate_blocks_({}) {}
void BasicCircuit::AppendBlock() { gate_blocks_.push_back({}); }
void BasicCircuit::AppendNoneParameterGate(const std::string &name, Matrix m, Indexes obj_qubits, Indexes ctrl_qubits) {
auto npg = std::make_shared<NoneParameterGate>(name, m, obj_qubits, ctrl_qubits);
gate_blocks_.back().push_back(npg);
}
void BasicCircuit::AppendParameterGate(const std::string &name, Indexes obj_qubits, Indexes ctrl_qubits,
const ParameterResolver &paras) {
if (name == "RX") {
auto pg_rx = std::make_shared<RXGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_rx);
} else if (name == "RY") {
auto pg_ry = std::make_shared<RYGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_ry);
} else if (name == "RZ") {
auto pg_rz = std::make_shared<RZGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_rz);
} else if (name == "XX") {
auto pg_xx = std::make_shared<XXGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_xx);
} else if (name == "YY") {
auto pg_yy = std::make_shared<YYGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_yy);
} else if (name == "ZZ") {
auto pg_zz = std::make_shared<ZZGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_zz);
} else if (name == "PS") {
auto pg_ps = std::make_shared<PhaseShiftGate>(obj_qubits, ctrl_qubits, paras);
gate_blocks_.back().push_back(pg_ps);
} else {
}
}
const GateBlocks &BasicCircuit::GetGateBlocks() const { return gate_blocks_; }
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_CCIRCUIT_H_
#define MINDQUANTUM_ENGINE_CCIRCUIT_H_
#include <vector>
#include <string>
#include <memory>
#include "mindquantum/gates/non_parameter_gate.h"
#include "mindquantum/gates/gates.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
using GateBlock = std::vector<std::shared_ptr<BasicGate>>;
using GateBlocks = std::vector<GateBlock>;
class BasicCircuit {
private:
GateBlocks gate_blocks_;
public:
BasicCircuit();
void AppendBlock();
void AppendNoneParameterGate(const std::string &, Matrix, Indexes, Indexes);
void AppendParameterGate(const std::string &, Indexes, Indexes, const ParameterResolver &);
const GateBlocks &GetGateBlocks() const;
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_CCIRCUIT_H_

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/gates/basic_gates.h"
#include <string>
namespace mindspore {
namespace mindquantum {
BasicGate::BasicGate(const std::string &name, bool is_parameter, const Indexes &obj_qubits, const Indexes &ctrl_qubits,
const ParameterResolver &paras)
: name_(name), is_parameter_(is_parameter), obj_qubits_(obj_qubits), ctrl_qubits_(ctrl_qubits), paras_(paras) {}
Matrix BasicGate::GetMatrix(const ParameterResolver &paras_out) {
Matrix gate_matrix_tmp;
return gate_matrix_tmp;
}
Matrix BasicGate::GetDiffMatrix(const ParameterResolver &paras_out) {
Matrix gate_matrix_tmp;
return gate_matrix_tmp;
}
Matrix &BasicGate::GetBaseMatrix() { return gate_matrix_base_; }
const ParameterResolver &BasicGate::GetParameterResolver() const { return paras_; }
bool BasicGate::IsParameterGate() { return is_parameter_; }
Indexes BasicGate::GetObjQubits() { return obj_qubits_; }
Indexes BasicGate::GetCtrlQubits() { return ctrl_qubits_; }
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_BASIC_GATES_H_
#define MINDQUANTUM_ENGINE_BASIC_GATES_H_
#include <string>
#include "mindquantum/parameter_resolver.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class BasicGate {
private:
std::string name_;
bool is_parameter_;
Matrix gate_matrix_base_;
Indexes obj_qubits_;
Indexes ctrl_qubits_;
ParameterResolver paras_;
public:
BasicGate();
BasicGate(const std::string &, bool, const Indexes &, const Indexes &,
const ParameterResolver &paras = ParameterResolver());
virtual Matrix GetMatrix(const ParameterResolver &);
virtual Matrix GetDiffMatrix(const ParameterResolver &);
virtual Matrix &GetBaseMatrix();
const ParameterResolver &GetParameterResolver() const;
bool IsParameterGate();
Indexes GetObjQubits();
Indexes GetCtrlQubits();
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_BASIC_GATES_H_

View File

@ -0,0 +1,154 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/gates/gates.h"
#include <cmath>
namespace mindspore {
namespace mindquantum {
RXGate::RXGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("RX", obj_qubits, ctrl_qubits, paras) {}
RXGate::RXGate() : IntrinsicOneParaGate("RX", {}, {}, {}) {}
Matrix RXGate::GetIntrinsicMatrix(CalcType theta) {
Matrix result = {{{cos(theta / 2), 0}, {0, -sin(theta / 2)}}, {{0, -sin(theta / 2)}, {cos(theta / 2), 0}}};
return result;
}
Matrix RXGate::GetIntrinsicDiffMatrix(CalcType theta) {
Matrix result = {{{-sin(theta / 2) / 2, 0}, {0, -cos(theta / 2) / 2}},
{{0, -cos(theta / 2) / 2}, {-sin(theta / 2) / 2, 0}}};
return result;
}
RYGate::RYGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("RY", obj_qubits, ctrl_qubits, paras) {}
Matrix RYGate::GetIntrinsicMatrix(CalcType theta) {
Matrix result = {{{cos(theta / 2), 0}, {-sin(theta / 2), 0}}, {{sin(theta / 2), 0}, {cos(theta / 2), 0}}};
return result;
}
Matrix RYGate::GetIntrinsicDiffMatrix(CalcType theta) {
Matrix result = {{{-sin(theta / 2) / 2, 0}, {-cos(theta / 2) / 2, 0}},
{{cos(theta / 2) / 2, 0}, {-sin(theta / 2) / 2, 0}}};
return result;
}
RZGate::RZGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("RZ", obj_qubits, ctrl_qubits, paras) {}
Matrix RZGate::GetIntrinsicMatrix(CalcType theta) {
Matrix result = {{{cos(theta / 2), -sin(theta / 2)}, {0, 0}}, {{0, 0}, {cos(theta / 2), sin(theta / 2)}}};
return result;
}
Matrix RZGate::GetIntrinsicDiffMatrix(CalcType theta) {
Matrix result = {{{-sin(theta / 2) / 2, -cos(theta / 2) / 2}, {0, 0}},
{{0, 0}, {-sin(theta / 2) / 2, cos(theta / 2) / 2}}};
return result;
}
PhaseShiftGate::PhaseShiftGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("PS", obj_qubits, ctrl_qubits, paras) {}
Matrix PhaseShiftGate::GetIntrinsicMatrix(CalcType theta) {
Matrix result = {{{1, 0}, {0, 0}}, {{0, 0}, {cos(theta), sin(theta)}}};
return result;
}
Matrix PhaseShiftGate::GetIntrinsicDiffMatrix(CalcType theta) {
Matrix result = {{{0, 0}, {0, 0}}, {{0, 0}, {-sin(theta), cos(theta)}}};
return result;
}
XXGate::XXGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("XX", obj_qubits, ctrl_qubits, paras) {}
Matrix XXGate::GetIntrinsicMatrix(CalcType theta) {
double c = cos(theta);
double s = sin(theta);
Matrix result = {{{c, 0}, {0, 0}, {0, 0}, {0, -s}},
{{0, 0}, {c, 0}, {0, -s}, {0, 0}},
{{0, 0}, {0, -s}, {c, 0}, {0, 0}},
{{0, -s}, {0, 0}, {0, 0}, {c, 0}}};
return result;
}
Matrix XXGate::GetIntrinsicDiffMatrix(CalcType theta) {
double c = cos(theta);
double s = sin(theta);
Matrix result = {{{-s, 0}, {0, 0}, {0, 0}, {0, -c}},
{{0, 0}, {-s, 0}, {0, -c}, {0, 0}},
{{0, 0}, {0, -c}, {-s, 0}, {0, 0}},
{{0, -c}, {0, 0}, {0, 0}, {-s, 0}}};
return result;
}
YYGate::YYGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("YY", obj_qubits, ctrl_qubits, paras) {}
Matrix YYGate::GetIntrinsicMatrix(CalcType theta) {
double c = cos(theta);
double s = sin(theta);
Matrix result = {{{c, 0}, {0, 0}, {0, 0}, {0, s}},
{{0, 0}, {c, 0}, {0, -s}, {0, 0}},
{{0, 0}, {0, -s}, {c, 0}, {0, 0}},
{{0, s}, {0, 0}, {0, 0}, {c, 0}}};
return result;
}
Matrix YYGate::GetIntrinsicDiffMatrix(CalcType theta) {
double c = cos(theta);
double s = sin(theta);
Matrix result = {{{-s, 0}, {0, 0}, {0, 0}, {0, c}},
{{0, 0}, {-s, 0}, {0, -c}, {0, 0}},
{{0, 0}, {0, -c}, {-s, 0}, {0, 0}},
{{0, c}, {0, 0}, {0, 0}, {-s, 0}}};
return result;
}
ZZGate::ZZGate(const Indexes &obj_qubits, const Indexes &ctrl_qubits, const ParameterResolver &paras)
: IntrinsicOneParaGate("ZZ", obj_qubits, ctrl_qubits, paras) {}
Matrix ZZGate::GetIntrinsicMatrix(CalcType theta) {
double c = cos(theta);
double s = sin(theta);
Matrix result = {{{c, -s}, {0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {c, s}, {0, 0}, {0, 0}},
{{0, 0}, {0, 0}, {c, s}, {0, 0}},
{{0, 0}, {0, 0}, {0, 0}, {c, -s}}};
return result;
}
Matrix ZZGate::GetIntrinsicDiffMatrix(CalcType theta) {
double c = cos(theta);
double s = sin(theta);
Matrix result = {{{-s, -c}, {0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {-s, c}, {0, 0}, {0, 0}},
{{0, 0}, {0, 0}, {-s, c}, {0, 0}},
{{0, 0}, {0, 0}, {0, 0}, {-s, -c}}};
return result;
}
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_GATES_H_
#define MINDQUANTUM_ENGINE_GATES_H_
#include "mindquantum/gates/intrinsic_one_para_gate.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class RXGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
RXGate(const Indexes &, const Indexes &, const ParameterResolver &);
RXGate();
};
class RYGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
RYGate(const Indexes &, const Indexes &, const ParameterResolver &);
};
class RZGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
RZGate(const Indexes &, const Indexes &, const ParameterResolver &);
};
class PhaseShiftGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
PhaseShiftGate(const Indexes &, const Indexes &, const ParameterResolver &);
};
class XXGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
XXGate(const Indexes &, const Indexes &, const ParameterResolver &);
};
class YYGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
YYGate(const Indexes &, const Indexes &, const ParameterResolver &);
};
class ZZGate : public IntrinsicOneParaGate {
Matrix GetIntrinsicMatrix(CalcType) override;
Matrix GetIntrinsicDiffMatrix(CalcType) override;
public:
ZZGate(const Indexes &, const Indexes &, const ParameterResolver &);
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_GATES_H_

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/gates/intrinsic_one_para_gate.h"
#include <string>
namespace mindspore {
namespace mindquantum {
Matrix IntrinsicOneParaGate::GetIntrinsicMatrix(CalcType theta) {
Matrix gate_matrix_tmp;
return gate_matrix_tmp;
}
Matrix IntrinsicOneParaGate::GetIntrinsicDiffMatrix(CalcType theta) {
Matrix gate_matrix_tmp;
return gate_matrix_tmp;
}
IntrinsicOneParaGate::IntrinsicOneParaGate(const std::string &name, const Indexes &obj_qubits,
const Indexes &ctrl_qubits, const ParameterResolver &paras)
: ParameterGate(name, obj_qubits, ctrl_qubits, paras) {}
CalcType IntrinsicOneParaGate::LinearCombination(const ParameterResolver &paras_in,
const ParameterResolver &paras_out) {
CalcType result = 0;
auto &paras_in_data = paras_in.GetData();
auto &paras_out_data = paras_out.GetData();
for (ParaType::const_iterator i = paras_in_data.begin(); i != paras_in_data.end(); ++i) {
result = result + paras_out_data.at(i->first) * (i->second);
}
return result;
}
Matrix IntrinsicOneParaGate::GetMatrix(const ParameterResolver &paras_out) {
return GetIntrinsicMatrix(LinearCombination(GetParameterResolver(), paras_out));
}
Matrix IntrinsicOneParaGate::GetDiffMatrix(const ParameterResolver &paras_out) {
return GetIntrinsicDiffMatrix(LinearCombination(GetParameterResolver(), paras_out));
}
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_INTRINSIC_ONE_PARAGATE_H_
#define MINDQUANTUM_ENGINE_INTRINSIC_ONE_PARAGATE_H_
#include <string>
#include "mindquantum/gates/parameter_gate.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class IntrinsicOneParaGate : public ParameterGate {
virtual Matrix GetIntrinsicMatrix(CalcType);
virtual Matrix GetIntrinsicDiffMatrix(CalcType);
public:
IntrinsicOneParaGate();
IntrinsicOneParaGate(const std::string &, const Indexes &, const Indexes &, const ParameterResolver &);
CalcType LinearCombination(const ParameterResolver &, const ParameterResolver &);
Matrix GetMatrix(const ParameterResolver &) override;
Matrix GetDiffMatrix(const ParameterResolver &) override;
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_INTRINSIC_ONE_PARAGATE_H_

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/gates/non_parameter_gate.h"
#include <string>
namespace mindspore {
namespace mindquantum {
NoneParameterGate::NoneParameterGate(const std::string &name, const Matrix &gate_matrix, const Indexes &obj_qubits,
const Indexes &ctrl_qubits)
: BasicGate(name, false, obj_qubits, ctrl_qubits), gate_matrix_(gate_matrix) {}
Matrix &NoneParameterGate::GetBaseMatrix() { return gate_matrix_; }
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_NON_PARAMETER_GATE_H_
#define MINDQUANTUM_ENGINE_NON_PARAMETER_GATE_H_
#include <string>
#include "mindquantum/gates/basic_gates.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class NoneParameterGate : public BasicGate {
private:
Matrix gate_matrix_;
public:
NoneParameterGate(const std::string &, const Matrix &, const Indexes &, const Indexes &);
Matrix &GetBaseMatrix() override;
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_NON_PARAMETER_GATE_H_

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/gates/parameter_gate.h"
#include <string>
namespace mindspore {
namespace mindquantum {
ParameterGate::ParameterGate(const std::string &name, const Indexes &obj_qubits, const Indexes &ctrl_qubits,
const ParameterResolver &paras)
: BasicGate(name, true, obj_qubits, ctrl_qubits, paras) {}
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_PARAMETER_GATE_H_
#define MINDQUANTUM_ENGINE_PARAMETER_GATE_H_
#include <string>
#include "mindquantum/gates/basic_gates.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class ParameterGate : public BasicGate {
public:
ParameterGate();
ParameterGate(const std::string &, const Indexes &, const Indexes &, const ParameterResolver &);
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_PARAMETER_GATE_H_

View File

@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/hamiltonian.h"
#include <utility>
namespace mindspore {
namespace mindquantum {
Hamiltonian::Hamiltonian() {}
Hamiltonian::Hamiltonian(const sparse::GoodHamilt &ham, Index n) : ham_(ham), n_qubits_(n) {}
sparse::DequeSparseHam Hamiltonian::TransHamiltonianPhaseOne(int n_thread1, const sparse::GoodHamilt &ham, Index n) {
sparse::DequeSparseHam ham_sparse;
ham_sparse.resize(ham.size());
int step = 0;
#pragma omp parallel for schedule(static) num_threads(n_thread1)
for (Index i = 0; i < ham.size(); i++) {
auto &gt = ham.at(i);
if (gt.second[0].first.size() == 0) {
ham_sparse[i] = sparse::IdentitySparse(n) * gt.first.first * gt.second[0].second;
} else {
ham_sparse[i] = sparse::GoodTerm2Sparse(gt, n);
}
if ((++step) % 20 == 0) std::cout << "\r" << step << "\t/" << ham.size() << "\tfinshed" << std::flush;
}
std::cout << "\ncalculate hamiltonian phase1 finished\n";
return ham_sparse;
}
int Hamiltonian::TransHamiltonianPhaseTwo(sparse::DequeSparseHam &ham_sparse, int n_thread2, int n_split) {
int n = ham_sparse.size();
while (n > 1) {
int half = n / 2 + n % 2;
std::cout << "n: " << n << "\t, half: " << half << "\n";
if (n < n_split) {
break;
}
#pragma omp parallel for schedule(static) num_threads(half)
for (int i = half; i < n; i++) {
ham_sparse[i - half] += ham_sparse[i];
}
ham_sparse.erase(ham_sparse.end() - half + n % 2, ham_sparse.end());
n = half;
}
std::cout << "total: " << ham_sparse.size() << " phase2 finished\n";
return n;
}
void Hamiltonian::SparseHamiltonian(int n_thread1, int n_thread2, int n_split) {
ham_sparse_ = Hamiltonian::TransHamiltonianPhaseOne(n_thread1, ham_, n_qubits_);
final_size_ = Hamiltonian::TransHamiltonianPhaseTwo(ham_sparse_, n_thread2, n_split);
}
void Hamiltonian::SetTermsDict(Simulator::TermsDict const &d) {
td_ = d;
Simulator::ComplexTermsDict().swap(ctd_);
for (auto &term : td_) {
ComplexType coeff = {term.second, 0};
ctd_.push_back(std::make_pair(term.first, coeff));
}
}
void Hamiltonian::Sparsed(bool s) { ham_sparsed_ = s; }
const Simulator::ComplexTermsDict &Hamiltonian::GetCTD() const { return ctd_; }
const Simulator::TermsDict &Hamiltonian::GetTD() const { return td_; }
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_CHAMILTONIAN_H_
#define MINDQUANTUM_ENGINE_CHAMILTONIAN_H_
#include "projectq/backends/_sim/_cppkernels/simulator.hpp"
#include "mindquantum/gates/basic_gates.h"
#include "mindquantum/sparse.h"
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class Hamiltonian {
private:
sparse::GoodHamilt ham_;
Index n_qubits_;
sparse::DequeSparseHam ham_sparse_;
Simulator::TermsDict td_;
Simulator::ComplexTermsDict ctd_;
int final_size_ = 1;
bool ham_sparsed_ = false;
public:
Hamiltonian();
Hamiltonian(const sparse::GoodHamilt &, Index);
sparse::DequeSparseHam TransHamiltonianPhaseOne(int, const sparse::GoodHamilt &, Index);
int TransHamiltonianPhaseTwo(sparse::DequeSparseHam &, int, int);
void SparseHamiltonian(int, int, int);
void SetTermsDict(Simulator::TermsDict const &);
void Sparsed(bool);
const Simulator::ComplexTermsDict &GetCTD() const;
const Simulator::TermsDict &GetTD() const;
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_CHAMILTONIAN_H_

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/parameter_resolver.h"
namespace mindspore {
namespace mindquantum {
ParameterResolver::ParameterResolver()
: data_(ParaType()), no_grad_parameters_(ParaSetType()), requires_grad_parameters_(ParaSetType()) {}
ParameterResolver::ParameterResolver(const ParaType &data, const ParaSetType &no_grad_parameters,
const ParaSetType &requires_grad_parameters)
: data_(data), no_grad_parameters_(no_grad_parameters), requires_grad_parameters_(requires_grad_parameters) {}
const ParaType &ParameterResolver::GetData() const { return data_; }
const ParaSetType &ParameterResolver::GetRequiresGradParameters() const { return requires_grad_parameters_; }
void ParameterResolver::SetData(const std::string &name, const CalcType &value) { data_[name] = value; }
void ParameterResolver::InsertNoGrad(const std::string &name) { no_grad_parameters_.insert(name); }
void ParameterResolver::InsertRequiresGrad(const std::string &name) { requires_grad_parameters_.insert(name); }
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_PARAMETER_RESOLVER_H_
#define MINDQUANTUM_ENGINE_PARAMETER_RESOLVER_H_
#include <map>
#include <string>
#include <set>
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
class ParameterResolver {
public:
ParameterResolver();
ParameterResolver(const ParaType &, const ParaSetType &, const ParaSetType &);
const ParaType &GetData() const;
const ParaSetType &GetRequiresGradParameters() const;
void SetData(const std::string &, const CalcType &);
void InsertNoGrad(const std::string &);
void InsertRequiresGrad(const std::string &);
private:
ParaType data_;
ParaSetType no_grad_parameters_;
ParaSetType requires_grad_parameters_;
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_PARAMETER_RESOLVER_H_

View File

@ -0,0 +1,192 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/pqc_simulator.h"
#include <omp.h>
#include <numeric>
namespace mindspore {
namespace mindquantum {
PQCSimulator::PQCSimulator() : Simulator(1), n_qubits_(1) {
PQCSimulator::AllocateAll();
for (Index i = 0; i < n_qubits_; i++) {
ordering_.push_back(i);
}
len_ = (1UL << n_qubits_);
}
PQCSimulator::PQCSimulator(Index seed = 1, Index N = 1) : Simulator(seed), n_qubits_(N) {
PQCSimulator::AllocateAll();
for (Index i = 0; i < n_qubits_; i++) {
ordering_.push_back(i);
}
len_ = (1UL << n_qubits_);
}
void PQCSimulator::ApplyGate(std::shared_ptr<BasicGate> g, const ParameterResolver &paras, bool diff) {
if (g->IsParameterGate()) {
if (diff) {
PQCSimulator::apply_controlled_gate(g->GetDiffMatrix(paras), g->GetObjQubits(), g->GetCtrlQubits());
} else {
PQCSimulator::apply_controlled_gate(g->GetMatrix(paras), g->GetObjQubits(), g->GetCtrlQubits());
}
} else {
PQCSimulator::apply_controlled_gate(g->GetBaseMatrix(), g->GetObjQubits(), g->GetCtrlQubits());
}
}
void PQCSimulator::ApplyBlock(const GateBlock &b, const mindquantum::ParameterResolver &paras) {
for (auto &g : b) {
PQCSimulator::ApplyGate(g, paras, false);
}
PQCSimulator::run();
}
void PQCSimulator::ApplyBlocks(const GateBlocks &bs, const ParameterResolver &paras) {
for (auto &b : bs) {
PQCSimulator::ApplyBlock(b, paras);
}
}
void PQCSimulator::Evolution(BasicCircuit const &circuit, ParameterResolver const &paras) {
PQCSimulator::ApplyBlocks(circuit.GetGateBlocks(), paras);
}
CalcType PQCSimulator::Measure(Index mask1, Index mask2, bool apply) {
CalcType out = 0;
#pragma omp parallel for reduction(+ : out) schedule(static)
for (unsigned i = 0; i < (1UL << n_qubits_); i++) {
if (((i & mask1) == mask1) && ((i | mask2) == mask2)) {
out = out + std::real(vec_[i]) * std::real(vec_[i]) + std::imag(vec_[i]) * std::imag(vec_[i]);
} else if (apply) {
vec_[i] = 0;
}
}
return out;
}
std::vector<std::vector<float>> PQCSimulator::CalcGradient(const std::shared_ptr<CalcGradientParam> &input_params,
PQCSimulator &s_left, PQCSimulator &s_right,
PQCSimulator &s_right_tmp) {
// Suppose the simulator already evaluate the circuit.
auto circuit = input_params->circuit_cp;
auto circuit_hermitian = input_params->circuit_hermitian_cp;
auto hamiltonians = input_params->hamiltonians_cp;
auto paras = input_params->paras_cp;
auto encoder_params_names = input_params->encoder_params_names_cp;
auto ansatz_params_names = input_params->ansatz_params_names_cp;
auto dummy_circuit_ = input_params->dummy_circuit_cp;
auto &circ_gate_blocks = circuit->GetGateBlocks();
auto &circ_herm_gate_blocks = circuit_hermitian->GetGateBlocks();
std::map<std::string, size_t> poi;
for (size_t i = 0; i < encoder_params_names->size(); i++) {
poi[encoder_params_names->at(i)] = i;
}
for (size_t i = 0; i < ansatz_params_names->size(); i++) {
poi[ansatz_params_names->at(i)] = i + encoder_params_names->size();
}
if (circ_gate_blocks.size() == 0 || circ_herm_gate_blocks.size() == 0) {
MS_LOG(EXCEPTION) << "Empty quantum circuit!";
}
unsigned len = circ_gate_blocks.at(0).size();
std::vector<float> grad(hamiltonians->size() * poi.size(), 0);
std::vector<float> e0(hamiltonians->size(), 0);
// #pragma omp parallel for
for (size_t h_index = 0; h_index < hamiltonians->size(); h_index++) {
auto &hamiltonian = hamiltonians->at(h_index);
s_right.set_wavefunction(vec_, ordering_);
s_left.set_wavefunction(s_right.vec_, ordering_);
s_left.apply_qubit_operator(hamiltonian.GetCTD(), ordering_);
e0[h_index] = static_cast<float>(ComplexInnerProduct(vec_, s_left.vec_, len_).real());
if (dummy_circuit_) {
continue;
}
for (unsigned i = 0; i < len; i++) {
if ((!circ_herm_gate_blocks.at(0)[i]->IsParameterGate()) ||
(circ_herm_gate_blocks.at(0)[i]->GetParameterResolver().GetRequiresGradParameters().size() == 0)) {
s_left.ApplyGate(circ_herm_gate_blocks.at(0)[i], *paras, false);
s_right.ApplyGate(circ_herm_gate_blocks.at(0)[i], *paras, false);
} else {
s_right.ApplyGate(circ_herm_gate_blocks.at(0)[i], *paras, false);
s_right.run();
s_right_tmp.set_wavefunction(s_right.vec_, ordering_);
s_right_tmp.ApplyGate(circ_gate_blocks.at(0)[len - 1 - i], *paras, true);
s_right_tmp.run();
s_left.run();
ComplexType gi = 0;
if (circ_herm_gate_blocks.at(0)[i]->GetCtrlQubits().size() == 0) {
gi = ComplexInnerProduct(s_left.vec_, s_right_tmp.vec_, len_);
} else {
gi = ComplexInnerProductWithControl(s_left.vec_, s_right_tmp.vec_, len_,
GetControlMask(circ_herm_gate_blocks.at(0)[i]->GetCtrlQubits()));
}
for (auto &it : circ_herm_gate_blocks.at(0)[i]->GetParameterResolver().GetRequiresGradParameters()) {
grad[h_index * poi.size() + poi[it]] -= static_cast<float>(
2 * circ_herm_gate_blocks.at(0)[i]->GetParameterResolver().GetData().at(it) * std::real(gi));
}
s_left.ApplyGate(circ_herm_gate_blocks.at(0)[i], *paras, false);
}
}
}
std::vector<float> grad1;
std::vector<float> grad2;
for (size_t i = 0; i < hamiltonians->size(); i++) {
for (size_t j = 0; j < poi.size(); j++) {
if (j < encoder_params_names->size()) {
grad1.push_back(grad[i * poi.size() + j]);
} else {
grad2.push_back(grad[i * poi.size() + j]);
}
}
}
return {e0, grad1, grad2};
}
void PQCSimulator::AllocateAll() {
for (unsigned i = 0; i < n_qubits_; i++) {
Simulator::allocate_qubit(i);
}
}
void PQCSimulator::DeallocateAll() {
for (unsigned i = 0; i < n_qubits_; i++) {
Simulator::deallocate_qubit(i);
}
}
void PQCSimulator::SetState(const StateVector &wavefunction) { Simulator::set_wavefunction(wavefunction, ordering_); }
std::size_t PQCSimulator::GetControlMask(Indexes const &ctrls) {
std::size_t ctrlmask =
std::accumulate(ctrls.begin(), ctrls.end(), 0, [&](Index a, Index b) { return a | (1UL << ordering_[b]); });
return ctrlmask;
}
void PQCSimulator::ApplyHamiltonian(const Hamiltonian &ham) {
Simulator::apply_qubit_operator(ham.GetCTD(), ordering_);
}
CalcType PQCSimulator::GetExpectationValue(const Hamiltonian &ham) {
return Simulator::get_expectation_value(ham.GetTD(), ordering_);
}
void PQCSimulator::SetZeroState() {
#pragma omp parallel for schedule(static)
for (size_t i = 0; i < len_; i++) {
vec_[i] = {0, 0};
}
vec_[0] = {1, 0};
}
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_PQC_SIMULATOR_H_
#define MINDQUANTUM_ENGINE_PQC_SIMULATOR_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "projectq/backends/_sim/_cppkernels/simulator.hpp"
#include "utils/log_adapter.h"
#include "mindquantum/gates/basic_gates.h"
#include "mindquantum/parameter_resolver.h"
#include "mindquantum/circuit.h"
#include "mindquantum/hamiltonian.h"
#include "mindquantum/utils.h"
#include "mindquantum/transformer.h"
namespace mindspore {
namespace mindquantum {
struct CalcGradientParam {
BasicCircuit *circuit_cp;
BasicCircuit *circuit_hermitian_cp;
transformer::Hamiltonians *hamiltonians_cp;
ParameterResolver *paras_cp;
transformer::NamesType *encoder_params_names_cp;
transformer::NamesType *ansatz_params_names_cp;
bool dummy_circuit_cp{false};
};
class PQCSimulator : public Simulator {
private:
Index n_qubits_;
Indexes ordering_;
Index len_;
public:
PQCSimulator();
PQCSimulator(Index seed, Index N);
void ApplyGate(std::shared_ptr<BasicGate>, const ParameterResolver &, bool);
void ApplyBlock(const GateBlock &, const ParameterResolver &);
void ApplyBlocks(const GateBlocks &, const ParameterResolver &);
void Evolution(const BasicCircuit &, const ParameterResolver &);
CalcType Measure(Index, Index, bool);
void ApplyHamiltonian(const Hamiltonian &);
CalcType GetExpectationValue(const Hamiltonian &);
std::vector<std::vector<float>> CalcGradient(const std::shared_ptr<CalcGradientParam> &, PQCSimulator &,
PQCSimulator &, PQCSimulator &);
void AllocateAll();
void DeallocateAll();
void SetState(const StateVector &);
std::size_t GetControlMask(Indexes const &);
void SetZeroState();
};
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_PQC_SIMULATOR_H_

View File

@ -0,0 +1,124 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/sparse.h"
namespace mindspore {
namespace mindquantum {
namespace sparse {
SparseMatrix BasiGateSparse(char g) {
SparseMatrix out(2, 2);
out.reserve(VectorXi::Constant(2, 2));
switch (g) {
case 'X':
case 'x':
out.insert(0, 1) = 1;
out.insert(1, 0) = 1;
break;
case 'Y':
case 'y':
out.insert(0, 1) = {0, -1};
out.insert(1, 0) = {0, 1};
break;
case 'Z':
case 'z':
out.insert(0, 0) = 1;
out.insert(1, 1) = -1;
break;
case '0':
out.insert(0, 0) = 1;
break;
case '1':
out.insert(1, 1) = 1;
break;
default:
out.insert(0, 0) = 1;
out.insert(1, 1) = 1;
break;
}
out.makeCompressed();
return out;
}
SparseMatrix IdentitySparse(int n_qubit) {
if (n_qubit == 0) {
int dim = 1UL << n_qubit;
SparseMatrix out(dim, dim);
out.reserve(VectorXi::Constant(dim, dim));
for (int i = 0; i < dim; i++) {
out.insert(i, i) = 1;
}
out.makeCompressed();
return out;
} else {
SparseMatrix out = BasiGateSparse('I');
for (int i = 1; i < n_qubit; i++) {
out = KroneckerProductSparse(out, BasiGateSparse('I')).eval();
}
return out;
}
}
SparseMatrix PauliTerm2Sparse(const PauliTerm &pt, Index _min, Index _max) {
int poi;
int n = pt.first.size();
SparseMatrix out;
if (pt.first[0].first == _min) {
out = BasiGateSparse(pt.first[0].second) * pt.second;
poi = 1;
} else {
out = BasiGateSparse('I') * pt.second;
poi = 0;
}
for (Index i = _min + 1; i <= _max; i++) {
if (poi == n) {
out = KroneckerProductSparse(IdentitySparse(_max - i + 1), out).eval();
break;
} else {
if (i == pt.first[poi].first) {
out = KroneckerProductSparse(BasiGateSparse(pt.first[poi++].second), out).eval();
} else {
out = KroneckerProductSparse(BasiGateSparse('I'), out).eval();
}
}
}
return out;
}
SparseMatrix GoodTerm2Sparse(const GoodTerm &gt, Index n_qubits) {
SparseMatrix out = PauliTerm2Sparse(gt.second[0], gt.first.second.first, gt.first.second.second);
for (Index i = 1; i < gt.second.size(); i++) {
out += PauliTerm2Sparse(gt.second[i], gt.first.second.first, gt.first.second.second);
}
out.prune({0.0, 0.0});
out *= gt.first.first;
out = KroneckerProductSparse(out, IdentitySparse(gt.first.second.first)).eval();
out = KroneckerProductSparse(IdentitySparse(n_qubits - gt.first.second.second - 1), out).eval();
return out;
}
} // namespace sparse
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_SPARSE_H_
#define MINDQUANTUM_ENGINE_SPARSE_H_
#include <Eigen/Dense>
#include <Eigen/Sparse>
#include <unsupported/Eigen/KroneckerProduct>
#include <deque>
#include <complex>
#include <utility>
#include <iostream>
#include <vector>
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
namespace sparse {
using PauliWord = std::pair<Index, char>;
using PauliTerm = std::pair<std::vector<PauliWord>, int>;
using GoodTerm = std::pair<std::pair<CalcType, std::pair<Index, Index>>, std::vector<PauliTerm>>;
using GoodHamilt = std::vector<GoodTerm>;
typedef Eigen::VectorXcd EigenComplexVector;
using Eigen::VectorXi;
typedef Eigen::SparseMatrix<ComplexType, Eigen::RowMajor, int64_t> SparseMatrix;
using DequeSparseHam = std::deque<SparseMatrix>;
using KroneckerProductSparse = Eigen::KroneckerProductSparse<SparseMatrix, SparseMatrix>;
SparseMatrix BasiGateSparse(char);
SparseMatrix IdentitySparse(int);
SparseMatrix PauliTerm2Sparse(const PauliTerm &, Index, Index);
SparseMatrix GoodTerm2Sparse(const GoodTerm &, Index);
} // namespace sparse
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_SPARSE_H_

View File

@ -0,0 +1,115 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/transformer.h"
#include <algorithm>
#include <utility>
namespace mindspore {
namespace mindquantum {
namespace transformer {
Matrix MatrixConverter(const MatrixType &matrix_real, const MatrixType &matrix_imag, bool hermitian) {
Matrix out;
for (Index i = 0; i < matrix_real.size(); i++) {
out.push_back({});
for (Index j = 0; j < matrix_real.size(); j++) {
if (hermitian)
out.back().push_back({stod(matrix_real[j][i]), -stod(matrix_imag[j][i])});
else
out.back().push_back({stod(matrix_real[i][j]), stod(matrix_imag[i][j])});
}
}
return out;
}
ParameterResolver ParameterResolverConverter(const ParaNameType &para_name, const CoeffType &coeff,
const RequireType &require_grad, bool hermitian) {
ParameterResolver pr;
for (Index i = 0; i < para_name.size(); i++) {
auto name = para_name[i];
if (hermitian)
pr.SetData(name, -coeff[i]);
else
pr.SetData(name, coeff[i]);
if (require_grad[i])
pr.InsertRequiresGrad(name);
else
pr.InsertNoGrad(name);
}
return pr;
}
std::vector<BasicCircuit> CircuitTransfor(const NamesType &names, const ComplexMatrixsType &matrixs,
const Indexess &objs_qubits, const Indexess &ctrls_qubits,
const ParasNameType &paras_name, const CoeffsType &coeffs,
const RequiresType &requires_grad) {
BasicCircuit circuit = BasicCircuit();
BasicCircuit herm_circuit = BasicCircuit();
circuit.AppendBlock();
herm_circuit.AppendBlock();
for (Index n = 0; n < names.size(); n++) {
Indexes obj(objs_qubits[n].size());
Indexes ctrl(ctrls_qubits[n].size());
std::transform(objs_qubits[n].begin(), objs_qubits[n].end(), obj.begin(),
[](const int64_t &i) { return (Index)(i); });
std::transform(ctrls_qubits[n].begin(), ctrls_qubits[n].end(), ctrl.begin(),
[](const int64_t &i) { return (Index)(i); });
if (names[n] == "npg")
// non parameterize gate
circuit.AppendNoneParameterGate("npg", MatrixConverter(matrixs[n][0], matrixs[n][1], false), obj, ctrl);
else
circuit.AppendParameterGate(names[n], obj, ctrl,
ParameterResolverConverter(paras_name[n], coeffs[n], requires_grad[n], false));
}
for (Index n = 0; n < names.size(); n++) {
Index tail = names.size() - 1 - n;
Indexes obj(objs_qubits[tail].size());
Indexes ctrl(ctrls_qubits[tail].size());
std::transform(objs_qubits[tail].begin(), objs_qubits[tail].end(), obj.begin(),
[](const int64_t &i) { return (Index)(i); });
std::transform(ctrls_qubits[tail].begin(), ctrls_qubits[tail].end(), ctrl.begin(),
[](const int64_t &i) { return (Index)(i); });
if (names[tail] == "npg")
// non parameterize gate
herm_circuit.AppendNoneParameterGate("npg", MatrixConverter(matrixs[tail][0], matrixs[tail][1], true), obj, ctrl);
else
herm_circuit.AppendParameterGate(
names[tail], obj, ctrl, ParameterResolverConverter(paras_name[tail], coeffs[tail], requires_grad[tail], true));
}
return {circuit, herm_circuit};
}
Hamiltonians HamiltoniansTransfor(const PaulisCoeffsType &paulis_coeffs, const PaulisWordsType &paulis_words,
const PaulisQubitsType &paulis_qubits) {
Hamiltonians hams;
for (Index n = 0; n < paulis_coeffs.size(); n++) {
Hamiltonian ham;
Simulator::TermsDict td;
for (Index i = 0; i < paulis_coeffs[n].size(); i++) {
Simulator::Term term;
for (Index j = 0; j < paulis_words[n][i].size(); j++)
term.push_back(std::make_pair((Index)(paulis_qubits[n][i][j]), paulis_words[n][i][j].at(0)));
td.push_back(std::make_pair(term, paulis_coeffs[n][i]));
}
ham.SetTermsDict(td);
hams.push_back(ham);
}
return hams;
}
} // namespace transformer
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_TRANSFORMER_H_
#define MINDQUANTUM_ENGINE_TRANSFORMER_H_
#include <vector>
#include <string>
#include "mindquantum/gates/gates.h"
#include "mindquantum/circuit.h"
#include "mindquantum/utils.h"
#include "mindquantum/parameter_resolver.h"
#include "mindquantum/hamiltonian.h"
namespace mindspore {
namespace mindquantum {
namespace transformer {
using NameType = std::string;
using MatrixColumnType = std::vector<std::string>;
using MatrixType = std::vector<MatrixColumnType>;
using ComplexMatrixType = std::vector<MatrixType>;
using ParaNameType = std::vector<std::string>;
using CoeffType = std::vector<float>;
using RequireType = std::vector<bool>;
using NamesType = std::vector<NameType>;
using ComplexMatrixsType = std::vector<ComplexMatrixType>;
using ParasNameType = std::vector<ParaNameType>;
using CoeffsType = std::vector<CoeffType>;
using RequiresType = std::vector<RequireType>;
using Indexess = std::vector<std::vector<int64_t>>;
using PauliCoeffsType = std::vector<float>;
using PaulisCoeffsType = std::vector<PauliCoeffsType>;
using PauliWordType = std::vector<std::string>;
using PauliWordsType = std::vector<PauliWordType>;
using PaulisWordsType = std::vector<PauliWordsType>;
using PauliQubitType = std::vector<int64_t>;
using PauliQubitsType = std::vector<PauliQubitType>;
using PaulisQubitsType = std::vector<PauliQubitsType>;
using Hamiltonians = std::vector<Hamiltonian>;
Hamiltonians HamiltoniansTransfor(const PaulisCoeffsType &, const PaulisWordsType &, const PaulisQubitsType &);
std::vector<BasicCircuit> CircuitTransfor(const NamesType &, const ComplexMatrixsType &, const Indexess &,
const Indexess &, const ParasNameType &, const CoeffsType &,
const RequiresType &);
Matrix MatrixConverter(const MatrixType &, const MatrixType &, bool);
ParameterResolver ParameterResolverConverter(const ParaNameType &, const CoeffType &, const RequireType &, bool);
} // namespace transformer
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_TRANSFORMER_H_

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindquantum/utils.h"
namespace mindspore {
namespace mindquantum {
ComplexType ComplexInnerProduct(const Simulator::StateVector &v1, const Simulator::StateVector &v2, unsigned len) {
CalcType real_part = 0;
CalcType imag_part = 0;
#pragma omp parallel for reduction(+ : real_part, imag_part)
for (Index i = 0; i < len; i++) {
real_part += v1[i].real() * v2[i].real() + v1[i].imag() * v2[i].imag();
imag_part += v1[i].real() * v2[i].imag() - v1[i].imag() * v2[i].real();
}
ComplexType result = {real_part, imag_part};
return result;
}
ComplexType ComplexInnerProductWithControl(const Simulator::StateVector &v1, const Simulator::StateVector &v2,
Index len, std::size_t ctrlmask) {
CalcType real_part = 0;
CalcType imag_part = 0;
#pragma omp parallel for reduction(+ : real_part, imag_part)
for (std::size_t i = 0; i < len; i++) {
if ((i & ctrlmask) == ctrlmask) {
real_part += v1[i].real() * v2[i].real() + v1[i].imag() * v2[i].imag();
imag_part += v1[i].real() * v2[i].imag() - v1[i].imag() * v2[i].real();
}
}
ComplexType result = {real_part, imag_part};
return result;
}
} // namespace mindquantum
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDQUANTUM_ENGINE_UTILS_H_
#define MINDQUANTUM_ENGINE_UTILS_H_
#include <string>
#include <complex>
#include <vector>
#include <map>
#include <set>
#include "projectq/backends/_sim/_cppkernels/intrin/alignedallocator.hpp"
#include "projectq/backends/_sim/_cppkernels/simulator.hpp"
namespace mindspore {
namespace mindquantum {
using CalcType = double;
using ComplexType = std::complex<CalcType>;
using ParaType = std::map<std::string, CalcType>;
using ParaSetType = std::set<std::string>;
using Matrix = std::vector<std::vector<ComplexType, aligned_allocator<ComplexType, 64>>>;
using Index = unsigned;
using Indexes = std::vector<Index>;
using ParaMapType = std::map<std::string, CalcType>;
ComplexType ComplexInnerProduct(const Simulator::StateVector &, const Simulator::StateVector &, Index);
ComplexType ComplexInnerProductWithControl(const Simulator::StateVector &, const Simulator::StateVector &, Index,
std::size_t);
const char kNThreads[] = "n_threads";
const char kNQubits[] = "n_qubits";
const char kEncoderParamsNames[] = "encoder_params_names";
const char kAnsatzParamsNames[] = "ansatz_params_names";
const char kGateNames[] = "gate_names";
const char kGateMatrix[] = "gate_matrix";
const char kGateObjQubits[] = "gate_obj_qubits";
const char kGateCtrlQubits[] = "gate_ctrl_qubits";
const char kGateParamsNames[] = "gate_params_names";
const char kGateCoeff[] = "gate_coeff";
const char kGateRequiresGrad[] = "gate_requires_grad";
const char kHamsPauliCoeff[] = "hams_pauli_coeff";
const char kHamsPauliWord[] = "hams_pauli_word";
const char kHamsPauliQubit[] = "hams_pauli_qubit";
} // namespace mindquantum
} // namespace mindspore
#endif // MINDQUANTUM_ENGINE_UTILS_H_

View File

@ -158,31 +158,32 @@ static std::string ExceptionTypeToString(ExceptionType type) {
static const char *GetSubModuleName(SubModuleId module_id) {
static const char *sub_module_names[NUM_SUBMODUES] = {
"UNKNOWN", // SM_UNKNOWN
"CORE", // SM_CORE
"ANALYZER", // SM_ANALYZER
"COMMON", // SM_COMMON
"DEBUG", // SM_DEBUG
"DEVICE", // SM_DEVICE
"GE_ADPT", // SM_GE_ADPT
"IR", // SM_IR
"KERNEL", // SM_KERNEL
"MD", // SM_MD
"ME", // SM_ME
"EXPRESS", // SM_EXPRESS
"OPTIMIZER", // SM_OPTIMIZER
"PARALLEL", // SM_PARALLEL
"PARSER", // SM_PARSER
"PIPELINE", // SM_PIPELINE
"PRE_ACT", // SM_PRE_ACT
"PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION
"UTILS", // SM_UTILS
"VM", // SM_VM
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"HCCL_ADPT" // SM_HCCL_ADPT
"UNKNOWN", // SM_UNKNOWN
"CORE", // SM_CORE
"ANALYZER", // SM_ANALYZER
"COMMON", // SM_COMMON
"DEBUG", // SM_DEBUG
"DEVICE", // SM_DEVICE
"GE_ADPT", // SM_GE_ADPT
"IR", // SM_IR
"KERNEL", // SM_KERNEL
"MD", // SM_MD
"ME", // SM_ME
"EXPRESS", // SM_EXPRESS
"OPTIMIZER", // SM_OPTIMIZER
"PARALLEL", // SM_PARALLEL
"PARSER", // SM_PARSER
"PIPELINE", // SM_PIPELINE
"PRE_ACT", // SM_PRE_ACT
"PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION
"UTILS", // SM_UTILS
"VM", // SM_VM
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"HCCL_ADPT", // SM_HCCL_ADPT
"MINDQUANTUM" // SM_MINDQUANTUM
};
return sub_module_names[module_id % NUM_SUBMODUES];
@ -425,7 +426,7 @@ class LogConfigParser {
bool ParseLogLevel(const std::string &str_level, MsLogLevel *ptr_level) {
if (str_level.size() == 1) {
int ch = str_level.c_str()[0];
ch = ch - '0'; // substract ASCII code of '0', which is 48
ch = ch - '0'; // subtract ASCII code of '0', which is 48
if (ch >= DEBUG && ch <= ERROR) {
if (ptr_level != nullptr) {
*ptr_level = static_cast<MsLogLevel>(ch);
@ -444,7 +445,7 @@ static MsLogLevel GetGlobalLogLevel() {
auto str_level = GetEnv("GLOG_v");
if (str_level.size() == 1) {
int ch = str_level.c_str()[0];
ch = ch - '0'; // substract ASCII code of '0', which is 48
ch = ch - '0'; // subtract ASCII code of '0', which is 48
if (ch >= DEBUG && ch <= ERROR) {
log_level = ch;
}

View File

@ -126,6 +126,7 @@ enum SubModuleId : int {
SM_PS, // Parameter Server
SM_LITE, // LITE
SM_HCCL_ADPT, // Hccl Adapter
SM_MINDQUANTUM, // MindQuantum
NUM_SUBMODUES // number of submodules
};

View File

@ -16,6 +16,7 @@
"""Generate bprop for other ops"""
from .. import operations as P
from .. import composite as C
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprop_getters
@ -46,3 +47,22 @@ def get_bprop_iou(self):
def bprop(x, y, out, dout):
return zeros_like(x), zeros_like(y)
return bprop
@bprop_getters.register(P.PQC)
def bprop_pqc(self):
"""Generate bprop for PQC"""
t = P.Transpose()
mul = P.Mul()
sum_ = P.ReduceSum()
def bprop(encoder_data, ansatz_data, out, dout):
dx = t(out[1], (2, 0, 1))
dx = mul(dout[0], dx)
dx = sum_(dx, 2)
dx = t(dx, (1, 0))
dy = C.tensor_dot(dout[0], out[2], ((0, 1), (0, 1)))
return dx, dy
return bprop

View File

@ -98,6 +98,7 @@ from .sparse_ops import SparseToDense
from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx,
SubAndFilter,
MapUniform, DynamicAssign, PadAndShift)
from .quantum_ops import PQC
__all__ = [
'Unique',
@ -419,6 +420,7 @@ __all__ = [
"MatrixInverse",
"Range",
"IndexAdd",
"PQC",
]
__all__.sort()

View File

@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for quantum computing."""
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
class PQC(PrimitiveWithInfer):
r"""
Evaluate a parameterized quantum circuit and calculate the gradient of each parameters.
Inputs of this operation is generated by MindQuantum framework.
Inputs:
- **n_qubits** (int) - The qubit number of quantum simulator.
- **encoder_params_names** (List[str]) - The parameters names of encoder circuit.
- **ansatz_params_names** (List[str]) - The parameters names of ansatz circuit.
- **gate_names** (List[str]) - The name of each gate.
- **gate_matrix** (List[List[List[List[float]]]]) - Real part and image
part of the matrix of quantum gate.
- **gate_obj_qubits** (List[List[int]]) - Object qubits of each gate.
- **gate_ctrl_qubits** (List[List[int]]) - Control qubits of each gate.
- **gate_params_names** (List[List[str]]) - Parameter names of each gate.
- **gate_coeff** (List[List[float]]) - Coefficient of eqch parameter of each gate.
- **gate_requires_grad** (List[List[bool]]) - Whether to calculate gradient
of parameters of gates.
- **hams_pauli_coeff** (List[List[float]]) - Coefficient of pauli words.
- **hams_pauli_word** (List[List[List[str]]]) - Pauli words.
- **hams_pauli_qubit** (List[List[List[int]]]) - The qubit that pauli matrix act on.
- **n_threads** (int) - Thread to evaluate input data.
Outputs:
- **expected_value** (Tensor) - The expected value of hamiltonian.
- **g1** (Tensor) - Gradient of encode circuit parameters.
- **g2** (Tensor) - Gradient of ansatz circuit parameters.
Supported Platforms:
``CPU``
"""
@prim_attr_register
def __init__(self, n_qubits, encoder_params_names, ansatz_params_names,
gate_names, gate_matrix, gate_obj_qubits, gate_ctrl_qubits,
gate_params_names, gate_coeff, gate_requires_grad,
hams_pauli_coeff, hams_pauli_word, hams_pauli_qubit,
n_threads):
self.init_prim_io_names(
inputs=['encoder_data', 'ansatz_data'],
outputs=['results', 'encoder_gradient', 'ansatz_gradient'])
self.n_hams = len(hams_pauli_coeff)
def check_shape_size(self, encoder_data, ansatz_data):
if len(encoder_data) != 2:
raise ValueError(
"PQC input encoder_data should have dimension size \
equal to 2, but got {}.".format(len(encoder_data)))
if len(ansatz_data) != 1:
raise ValueError(
"PQC input ansatz_data should have dimension size \
equal to 1, but got {}.".format(len(ansatz_data)))
def infer_shape(self, encoder_data, ansatz_data):
self.check_shape_size(encoder_data, ansatz_data)
return [encoder_data[0], self.n_hams], [
encoder_data[0], self.n_hams,
len(self.encoder_params_names)
], [encoder_data[0], self.n_hams,
len(self.ansatz_params_names)]
def infer_dtype(self, encoder_data, ansatz_data):
args = {'encoder_data': encoder_data, 'ansatz_data': ansatz_data}
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type,
self.name)
return encoder_data, encoder_data, encoder_data

View File

@ -2753,7 +2753,63 @@ test_case_quant_ops = [
'skip': ['backward']}),
]
test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, test_case_other_ops, test_case_quant_ops]
test_case_quantum_ops = [
('PQC', {
'block': P.PQC(n_qubits=3,
encoder_params_names=['e0', 'e1', 'e2'],
ansatz_params_names=['a', 'b', 'c'],
gate_names=['RX', 'RX', 'RX', 'npg', 'npg',
'npg', 'RX', 'npg', 'npg', 'RZ',
'npg', 'npg', 'RY'],
gate_matrix=[[[['0.0', '0.0'], ['0.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '0.0'], ['0.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '0.0'], ['0.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.7071067811865475', '0.7071067811865475'],
['0.7071067811865475', '-0.7071067811865475']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.7071067811865475', '0.7071067811865475'],
['0.7071067811865475', '-0.7071067811865475']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.7071067811865475', '0.7071067811865475'],
['0.7071067811865475', '-0.7071067811865475']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '0.0'], ['0.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '1.0'], ['1.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '-0.0'], ['0.0', '0.0']],
[['0.0', '-1.0'], ['1.0', '0.0']]],
[[['0.0', '0.0'], ['0.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '1.0'], ['1.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['1.0', '0.0'], ['0.0', '-1.0']],
[['0.0', '0.0'], ['0.0', '0.0']]],
[[['0.0', '0.0'], ['0.0', '0.0']],
[['0.0', '0.0'], ['0.0', '0.0']]]],
gate_obj_qubits=[[0], [1], [2], [0], [1], [2],
[0], [1], [2], [1], [1], [0], [2]],
gate_ctrl_qubits=[[], [], [], [], [], [], [], [], [], [], [2], [], []],
gate_params_names=[['e0'], ['e1'], ['e2'], [], [], [], ['a'], [], [],
['b'], [], [], ['c']],
gate_coeff=[[1.0], [1.0], [1.0], [], [], [], [1.0], [], [], [1.0], [],
[], [1.0]],
gate_requires_grad=[[True], [True], [True], [], [], [], [True], [], [],
[True], [], [], [True]],
hams_pauli_coeff=[[1.0]],
hams_pauli_word=[[['X', 'Y', 'Z']]],
hams_pauli_qubit=[[[0, 1, 2]]],
n_threads=1),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]).astype(np.float32)),
Tensor(np.array([2.0, 3.0, 4.0]).astype(np.float32))],
'skip': ['backward']}),
]
test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops,
test_case_other_ops, test_case_quant_ops, test_case_quantum_ops]
test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm

View File

@ -0,0 +1,41 @@
--- ProjectQ-0.5.1/projectq/backends/_sim/_cppkernels/simulator.hpp 2020-06-05 21:07:57.000000000 +0800
+++ ProjectQ-0.5.1_new/projectq/backends/_sim/_cppkernels/simulator.hpp 2021-01-14 10:52:24.822039389 +0800
@@ -33,7 +33,6 @@
#include <random>
#include <functional>
-
class Simulator{
public:
using calc_type = double;
@@ -44,8 +43,9 @@ public:
using Term = std::vector<std::pair<unsigned, char>>;
using TermsDict = std::vector<std::pair<Term, calc_type>>;
using ComplexTermsDict = std::vector<std::pair<Term, complex_type>>;
+ StateVector vec_;
- Simulator(unsigned seed = 1) : N_(0), vec_(1,0.), fusion_qubits_min_(4),
+ Simulator(unsigned seed = 1) : vec_(1,0.), N_(0), fusion_qubits_min_(4),
fusion_qubits_max_(5), rnd_eng_(seed) {
vec_[0]=1.; // all-zero initial state
std::uniform_real_distribution<double> dist(0., 1.);
@@ -562,7 +562,6 @@ private:
}
unsigned N_; // #qubits
- StateVector vec_;
Map map_;
Fusion fused_gates_;
unsigned fusion_qubits_min_, fusion_qubits_max_;
@@ -570,10 +569,8 @@ private:
std::function<double()> rng_;
// large array buffers to avoid costly reallocations
- static StateVector tmpBuff1_, tmpBuff2_;
+ StateVector tmpBuff1_, tmpBuff2_;
};
-Simulator::StateVector Simulator::tmpBuff1_;
-Simulator::StateVector Simulator::tmpBuff2_;
#endif