"""
Janus 编译器优化 Pass
基础优化 pass 实现
"""
from abc import ABC, abstractmethod
from typing import List, Optional
import numpy as np
from janus.circuit.dag import DAGCircuit, DAGNode, NodeType
[文档]
class BasePass(ABC):
"""优化 Pass 基类"""
@property
@abstractmethod
def name(self) -> str:
"""Pass 名称"""
pass
[文档]
@abstractmethod
def run(self, dag: DAGCircuit) -> DAGCircuit:
"""
执行优化
Args:
dag: 输入 DAG
Returns:
优化后的 DAG
"""
pass
[文档]
class RemoveIdentityPass(BasePass):
"""
移除恒等门
移除对电路没有影响的门,如 I 门
"""
@property
def name(self) -> str:
return "remove_identity"
[文档]
def run(self, dag: DAGCircuit) -> DAGCircuit:
nodes_to_remove = []
for node in dag.op_nodes():
if node.op and node.op.name.lower() in ('id', 'i', 'identity'):
nodes_to_remove.append(node)
for node in nodes_to_remove:
dag.remove_op_node(node)
return dag
[文档]
class CancelInversesPass(BasePass):
"""
消除相邻的逆门对
例如: X-X, H-H, CX-CX (相同控制和目标) 等
"""
# 自逆门(自己是自己的逆)
SELF_INVERSE = {
# 单比特 Pauli 门
'x', 'y', 'z', 'h',
# 两比特门
'cx', 'cy', 'cz', 'ch', 'swap', 'iswap', 'dcx', 'ecr',
# 三比特门
'ccx', 'ccz', 'cswap',
# 多控制门
'mcx', 'mcx_gray', 'mcx_recursive', 'mcx_vchain',
'c3x', 'c4x', 'rccx', 'rc3x',
}
@property
def name(self) -> str:
return "cancel_inverses"
[文档]
def run(self, dag: DAGCircuit) -> DAGCircuit:
changed = True
while changed:
changed = False
nodes_to_remove = []
for node in list(dag.op_nodes()):
if node.node_id not in dag._nodes:
continue
name = node.op.name.lower() if node.op else ""
if name not in self.SELF_INVERSE:
continue
# 检查后继节点
for succ in dag.successors(node):
if succ.node_type != NodeType.OP:
continue
succ_name = succ.op.name.lower() if succ.op else ""
# 相同门,相同量子比特
if (succ_name == name and
succ.qubits == node.qubits and
self._is_direct_successor(dag, node, succ)):
nodes_to_remove.append((node, succ))
changed = True
break
# 移除配对的门
for node1, node2 in nodes_to_remove:
if node1.node_id in dag._nodes:
dag.remove_op_node(node1)
if node2.node_id in dag._nodes:
dag.remove_op_node(node2)
return dag
def _is_direct_successor(self, dag: DAGCircuit, node1: DAGNode, node2: DAGNode) -> bool:
"""检查 node2 是否是 node1 在所有相关量子比特上的直接后继"""
for q in node1.qubits:
# 找 node1 在量子比特 q 上的直接后继
found = False
for succ in dag.successors(node1):
if q in succ.qubits and succ.node_id == node2.node_id:
found = True
break
if not found:
return False
return True
[文档]
class MergeRotationsPass(BasePass):
"""
合并连续的旋转门
例如: RZ(a) - RZ(b) -> RZ(a+b)
"""
# 单参数旋转门(可以合并角度)
ROTATION_GATES = {
# 单比特旋转门
'rx', 'ry', 'rz', 'p', 'u1',
# 两比特旋转门
'rxx', 'ryy', 'rzz', 'rzx',
# 受控旋转门
'crx', 'cry', 'crz', 'cp', 'cu1',
# 多控制旋转门
'mcrx', 'mcry', 'mcrz', 'mcp', 'mcphase', 'mcu1',
}
@property
def name(self) -> str:
return "merge_rotations"
[文档]
def run(self, dag: DAGCircuit) -> DAGCircuit:
changed = True
while changed:
changed = False
for node in list(dag.op_nodes()):
if node.node_id not in dag._nodes:
continue
name = node.op.name.lower() if node.op else ""
if name not in self.ROTATION_GATES:
continue
# 检查后继
for succ in dag.successors(node):
if succ.node_type != NodeType.OP:
continue
succ_name = succ.op.name.lower() if succ.op else ""
# 相同旋转门,相同量子比特,直接后继
if (succ_name == name and
succ.qubits == node.qubits and
len(node.qubits) == 1):
# 合并角度
angle1 = node.op.params[0] if node.op.params else 0
angle2 = succ.op.params[0] if succ.op.params else 0
new_angle = float(angle1) + float(angle2)
# 归一化到 [-2π, 2π]
new_angle = new_angle % (2 * np.pi)
# 如果角度接近 0,移除两个门
if abs(new_angle) < 1e-10 or abs(new_angle - 2*np.pi) < 1e-10:
dag.remove_op_node(succ)
dag.remove_op_node(node)
else:
# 更新第一个门的角度,移除第二个
node.op.params[0] = new_angle
dag.remove_op_node(succ)
changed = True
break
return dag