janus.encode.efficient_sparse 源代码

"""
高效稀疏编码模块

使用 Janus 电路库实现高效稀疏状态编码,逻辑同步自 Encode.cpp。
"""

from typing import Union, List, Dict, Tuple, Optional
import numpy as np
from math import asin

from janus.circuit import Circuit
from .utils import _build_state_dict
from janus.circuit.library import U3Gate


[文档] class EfficientSparseResult: """高效稀疏编码结果类,包含电路和输出比特信息"""
[文档] def __init__(self, circuit: Circuit, out_qubits: List[int]): """ 初始化编码结果 参数: circuit: 编码电路 out_qubits: 输出比特索引列表 """ self._circuit = circuit self._out_qubits = out_qubits
@property def circuit(self) -> Circuit: """获取编码电路""" return self._circuit @property def out_qubits(self) -> List[int]: """获取输出比特索引列表""" return self._out_qubits
[文档] def get_circuit(self) -> Circuit: """获取编码电路(兼容方法)""" return self._circuit
[文档] def get_out_qubits(self) -> List[int]: """获取输出比特(兼容方法)""" return self._out_qubits
[文档] def measure(self) -> Circuit: """ 在输出比特上添加测量,返回带测量的电路副本 返回: Circuit: 带测量操作的电路 """ # 创建带有足够经典比特的新电路 n_clbits = len(self._out_qubits) measured_circuit = Circuit(self._circuit.num_qubits, n_clbits) # 复制原电路的所有指令 for inst in self._circuit.instructions: measured_circuit.append(inst.operation, inst.qubits, inst.clbits) # 添加测量 for i, qubit in enumerate(self._out_qubits): measured_circuit.measure(qubit, i) return measured_circuit
def _maximizing_difference_bit_search( b_strings: List[str], dif_qubits: List[int] ) -> Tuple[int, List[str], List[str]]: """寻找最大化集合差异的比特位""" bit_index = 0 set_difference = -1 t0_res, t1_res = [], [] n_bits = len(b_strings[0]) bit_search_space = [i for i in range(n_bits) if i not in dif_qubits] for bit in bit_search_space: temp_t0 = [s for s in b_strings if s[bit] == "0"] temp_t1 = [s for s in b_strings if s[bit] == "1"] if temp_t0 and temp_t1: temp_difference = abs(len(temp_t0) - len(temp_t1)) if set_difference == -1 or temp_difference > set_difference: t0_res, t1_res = temp_t0, temp_t1 bit_index = bit set_difference = temp_difference return bit_index, t0_res, t1_res def _build_bit_string_set( b_strings: List[str], bitstr1: str, dif_qubits: List[int], dif_values: List[int] ) -> List[str]: """构建满足特定比特值条件的字符串集合""" bit_string_set = [] for b_string in b_strings: if b_string == bitstr1: continue match = True for i, qubit_idx in enumerate(dif_qubits): if int(b_string[qubit_idx]) != dif_values[i]: match = False break if match: bit_string_set.append(b_string) return bit_string_set def _bit_string_search( b_strings: List[str], dif_qubits: List[int], dif_values: List[int] ) -> List[str]: """搜索满足条件的单一比特字符串""" temp_strings = b_strings[:] while len(temp_strings) > 1: bit, t0, t1 = _maximizing_difference_bit_search(temp_strings, dif_qubits) if bit not in dif_qubits: dif_qubits.append(bit) if len(t0) < len(t1): dif_values.append(0) temp_strings = t0 else: dif_values.append(1) temp_strings = t1 return temp_strings def _search_bit_strings_for_merging( state: Dict[str, complex], ) -> Tuple[str, str, int, List[int]]: """寻找待合并的两个比特字符串及差异位""" b_strings = list(state.keys()) dif_qubits = [] dif_values = [] if len(b_strings) == 2: bit, t0, t1 = _maximizing_difference_bit_search(b_strings, []) return t1[0], t0[0], bit, [] else: # 寻找第一个字符串 res1 = _bit_string_search(b_strings, dif_qubits, dif_values) bitstr1 = res1[0] # 弹出最后一个差异位,用于后续寻找第二个字符串 dif_qubit = dif_qubits.pop() dif_values.pop() # 寻找第二个字符串 b_strings2 = _build_bit_string_set(b_strings, bitstr1, dif_qubits, dif_values) res2 = _bit_string_search(b_strings2, dif_qubits, dif_values) bitstr2 = res2[0] return bitstr1, bitstr2, dif_qubit, dif_qubits def _apply_x_to_bit_string(b_string: str, qubit_idx: int) -> str: """对比特字符串应用 X 门效果""" s_list = list(b_string) s_list[qubit_idx] = "1" if s_list[qubit_idx] == "0" else "0" return "".join(s_list) def _apply_cx_to_bit_string(b_string: str, control: int, target: int) -> str: """对比特字符串应用 CX 门效果""" if b_string[control] == "1": return _apply_x_to_bit_string(b_string, target) return b_string def _update_state_dict( state: Dict[str, complex], operation: str, qubit_idx: Optional[int] = None, control: Optional[int] = None, target: Optional[int] = None, merge_strings: Optional[Tuple[str, str]] = None, ) -> Dict[str, complex]: """更新状态字典以反映门操作或合并""" new_state = {} if operation == "merge": if merge_strings: s1, s2 = merge_strings amp1, amp2 = state[s1], state[s2] norm = (abs(amp1) ** 2 + abs(amp2) ** 2) ** 0.5 new_state = state.copy() del new_state[s2] new_state[s1] = complex(norm) elif operation == "x": for k, v in state.items(): new_state[_apply_x_to_bit_string(k, qubit_idx)] = v elif operation == "cx": for k, v in state.items(): new_state[_apply_cx_to_bit_string(k, control, target)] = v return new_state def _compute_angles(amp1: complex, amp2: complex) -> Tuple[float, float, float]: """计算 U3 旋转角度""" norm = (abs(amp1) ** 2 + abs(amp2) ** 2) ** 0.5 if norm < 1e-14: return (0.0, 0.0, 0.0) # 逻辑同步自 Encode.cpp:2266 theta = 2 * asin(max(0.0, min(1.0, abs(amp2 / norm)))) # 使用 np.angle 获取相位 phi = -np.angle(amp2 / norm) lam = -np.angle(amp1 / norm) - phi return (float(theta), float(phi), float(lam)) def _merging_procedure( state: Dict[str, complex], circuit: Circuit, q_indices: List[int] ) -> Dict[str, complex]: # 1. 搜索待合并的比特串 bitstr1, bitstr2, dif, dif_qubits = _search_bit_strings_for_merging(state) # 2. 预处理 (同步自 _preprocess_states_for_merging) # 确保 bitstr1 在 dif 位为 '1' if bitstr1[dif] != "1": circuit.x(q_indices[dif]) bitstr1 = _apply_x_to_bit_string(bitstr1, dif) bitstr2 = _apply_x_to_bit_string(bitstr2, dif) state = _update_state_dict(state, "x", qubit_idx=dif) # 使两个比特串在除 dif 以外的位上相同 (同步自 _equalize_bit_string_states) for i in range(len(bitstr1)): if i != dif and bitstr1[i] != bitstr2[i]: circuit.cx(q_indices[dif], q_indices[i]) bitstr1 = _apply_cx_to_bit_string(bitstr1, dif, i) bitstr2 = _apply_cx_to_bit_string(bitstr2, dif, i) state = _update_state_dict(state, "cx", control=dif, target=i) # 将 bitstr2 在 dif_qubits 位上设为 '1' 以满足控制条件 (同步自 _apply_not_gates_to_qubit_index_list) for b_idx in dif_qubits: if bitstr2[b_idx] != "1": circuit.x(q_indices[b_idx]) bitstr1 = _apply_x_to_bit_string(bitstr1, b_idx) bitstr2 = _apply_x_to_bit_string(bitstr2, b_idx) state = _update_state_dict(state, "x", qubit_idx=b_idx) # 3. 计算并应用受控旋转 angles = _compute_angles(state[bitstr1], state[bitstr2]) control_qubits = [q_indices[i] for i in dif_qubits] if not control_qubits: circuit.u3(*angles, q_indices[dif]) else: # 使用 Janus 的受控门 circuit.gate(U3Gate(*angles), q_indices[dif]).control(control_qubits) # 4. 更新状态字典 state = _update_state_dict(state, "merge", merge_strings=(bitstr1, bitstr2)) return state
[文档] def efficient_sparse( q_size: int, data: Union[List[float], List[complex], Dict[str, Union[float, complex]]], add_measure: bool = False, ) -> EfficientSparseResult: """ 高效稀疏编码 (Efficient Sparse Encoding) 通过一系列合并操作将状态简化为基态,然后对整个电路求逆。 参数: q_size: 可用的量子比特总数 data: 量子态数据(列表、数组或字典格式) add_measure: 是否在输出比特上添加测量(默认 False) 返回: EfficientSparseResult: 包含电路和输出比特信息的结果对象 - result.circuit: 编码电路 - result.out_qubits: 输出比特索引列表 - result.measure(): 返回带测量的电路副本 """ # 输入转换 if isinstance(data, (list, np.ndarray)): state = _build_state_dict(data) elif isinstance(data, dict): state = {k: complex(v) for k, v in data.items()} else: raise TypeError("输入数据必须是列表、数组或字典") if not state: raise ValueError("输入数据不能为空") # 验证归一化 tmp_sum = sum(abs(amp) ** 2 for amp in state.values()) if abs(1.0 - tmp_sum) > 1e-13: if tmp_sum < 1e-13: raise ValueError("输入向量为零") # 自动归一化 factor = np.sqrt(tmp_sum) state = {k: v / factor for k, v in state.items()} first_key = next(iter(state.keys())) n_qubits = len(first_key) if n_qubits > q_size: raise ValueError(f"需要 {n_qubits} 个量子比特,但只有 {q_size} 个可用") # 这里的 q_indices 对应 C++ 中的 reverse_q # Encode.cpp:1836 reverse_q[i] = q[q.size()-1-i] q_indices = [q_size - 1 - i for i in range(q_size)] circuit = Circuit(q_size) # 循环合并直到只剩一个状态 current_state = state.copy() while len(current_state) > 1: current_state = _merging_procedure(current_state, circuit, q_indices) # 处理最后一个基态对应的 X 门 final_bitstr = next(iter(current_state.keys())) for i, bit in enumerate(final_bitstr): if bit == "1": circuit.x(q_indices[i]) # 全局求逆得到制备电路 final_circuit = circuit.inverse() # 计算输出比特,逻辑同步自 Encode.cpp:1856-1858 # for (int i = n_qubits - 1; i >= 0; --i) { # m_out_qubits.push_back(reverse_q[i]); # } # 即:reverse_q[n_qubits-1], reverse_q[n_qubits-2], ..., reverse_q[0] out_qubits = [q_indices[i] for i in range(n_qubits - 1, -1, -1)] # 如果需要添加测量 if add_measure: for i, qubit in enumerate(out_qubits): final_circuit.measure(qubit, i) return EfficientSparseResult(final_circuit, out_qubits)