Skip to content

Commit

Permalink
Merge pull request #192 from Zhaoyilunnn/main
Browse files Browse the repository at this point in the history
fix: make the parameter shift example compatible with latest tq
  • Loading branch information
Hanrui-Wang authored Sep 20, 2023
2 parents 8db9260 + 744b49a commit c99b0af
Show file tree
Hide file tree
Showing 2 changed files with 5,596 additions and 5,117 deletions.
12 changes: 6 additions & 6 deletions examples/param_shift_onchip_training/param_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.layer import SethLayer0
from torchquantum.layer.layers import SethLayer0

from torchquantum.dataset import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR
Expand All @@ -39,7 +39,6 @@ class QFCModel(tq.QuantumModule):
def __init__(self):
super().__init__()
self.n_wires = 4
self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_ryzxy"])

self.arch = {"n_wires": self.n_wires, "n_blocks": 2, "n_layers_per_block": 2}
Expand All @@ -49,16 +48,17 @@ def __init__(self):

def forward(self, x, use_qiskit=False):
bsz = x.shape[0]
q_device = tq.QuantumDevice(n_wires=self.n_wires, bsz=bsz)
x = F.avg_pool2d(x, 6).view(bsz, 16)

if use_qiskit:
x = self.qiskit_processor.process_parameterized(
self.q_device, self.encoder, self.q_layer, self.measure, x
q_device, self.encoder, self.q_layer, self.measure, x
)
else:
self.encoder(self.q_device, x)
self.q_layer(self.q_device)
x = self.measure(self.q_device)
self.encoder(q_device, x)
self.q_layer(q_device)
x = self.measure(q_device)

x = x.reshape(bsz, 4)

Expand Down
Loading

0 comments on commit c99b0af

Please sign in to comment.