Skip to content

Commit

Permalink
Release 2.22.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Lienemann committed Dec 21, 2023
1 parent 5ae6aee commit 310364d
Show file tree
Hide file tree
Showing 73 changed files with 1,789 additions and 799 deletions.
2 changes: 1 addition & 1 deletion laboneq/VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.21.0
2.22.0
66 changes: 43 additions & 23 deletions laboneq/compiler/code_generator/analyze_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ class IntervalStartEvent:
play_wave_id: str
acquisition_type: list
acquire_handle: str
oscillator_frequency: float
amplitude: float
play_pulse_parameters: Optional[Dict[str, Any]]
pulse_pulse_parameters: Optional[Dict[str, Any]]
channels: list[int | list[int]]
Expand All @@ -382,7 +384,7 @@ class IntervalEndEvent:
time: float
play_wave_id: str

interval_zip = list(
interval_zip: list[tuple[IntervalStartEvent, IntervalEndEvent]] = list(
zip(
[
IntervalStartEvent(
Expand All @@ -391,6 +393,8 @@ class IntervalEndEvent:
event["play_wave_id"],
event.get("acquisition_type", []),
event["acquire_handle"],
event.get("oscillator_frequency", 0.0),
event.get("amplitude", 1.0),
event.get("play_pulse_parameters"),
event.get("pulse_pulse_parameters"),
event.get("channel") or channels,
Expand Down Expand Up @@ -444,6 +448,8 @@ class IntervalEndEvent:
"play_wave_id": interval_start.play_wave_id,
"acquisition_type": interval_start.acquisition_type,
"acquire_handles": [interval_start.acquire_handle],
"oscillator_frequency": interval_start.oscillator_frequency,
"amplitude": interval_start.amplitude,
"feedback_register": feedback_register,
"channels": interval_start.channels,
"play_pulse_parameters": interval_start.play_pulse_parameters,
Expand Down Expand Up @@ -549,14 +555,19 @@ def analyze_prng_times(events, sampling_rate, delay):
filtered_events = (
(index, event)
for index, event in enumerate(events)
if event["event_type"] == EventType.SETUP_PRNG
if event["event_type"]
in (
EventType.PRNG_SETUP,
EventType.DROP_PRNG_SETUP,
EventType.DRAW_PRNG_SAMPLE,
EventType.DROP_PRNG_SAMPLE,
)
)
for index, event in filtered_events:
event_time_in_samples = length_to_samples(event["time"] + delay, sampling_rate)
retval.add(
event_time_in_samples,
AWGEvent(
type=AWGEventType.SEED_PRNG,
if event["event_type"] == EventType.PRNG_SETUP:
awg_event = AWGEvent(
type=AWGEventType.SETUP_PRNG,
start=event_time_in_samples,
end=event_time_in_samples,
priority=index,
Expand All @@ -565,25 +576,34 @@ def analyze_prng_times(events, sampling_rate, delay):
"seed": event["seed"],
"section": event["section_name"],
},
),
)

filtered_events = (
(index, event)
for index, event in enumerate(events)
if event["event_type"] == EventType.SAMPLE_PRNG
)
for index, event in filtered_events:
event_time_in_samples = length_to_samples(event["time"] + delay, sampling_rate)
retval.add(
event_time_in_samples,
AWGEvent(
type=AWGEventType.SAMPLE_PRNG,
)
elif event["event_type"] == EventType.DROP_PRNG_SETUP:
awg_event = AWGEvent(
type=AWGEventType.DROP_PRNG_SETUP,
start=event_time_in_samples,
end=event_time_in_samples,
priority=index,
params={},
),
)
params={"section": event["section_name"]},
)
elif event["event_type"] == EventType.DRAW_PRNG_SAMPLE:
awg_event = AWGEvent(
type=AWGEventType.PRNG_SAMPLE,
start=event_time_in_samples,
end=event_time_in_samples,
priority=index,
params={
"sample_name": event["sample_name"],
"section_name": event["section_name"],
},
)
else: # EventType.DROP_PRNG_SAMPLE
awg_event = AWGEvent(
type=AWGEventType.DROP_PRNG_SAMPLE,
start=event_time_in_samples,
end=event_time_in_samples,
priority=index,
params={"sample_name": event["sample_name"]},
)
retval.add(event_time_in_samples, awg_event)

return retval
96 changes: 72 additions & 24 deletions laboneq/compiler/code_generator/analyze_playback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import math
from bisect import bisect_left
from dataclasses import dataclass
from math import ceil
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from intervaltree import Interval, IntervalTree
Expand Down Expand Up @@ -102,7 +103,7 @@ class FeedbackIntervalData:
handle: str | None
local: bool | None
user_register: int | None
match_prng: bool
prng_sample: str | None


def _analyze_branches(events, delay, sampling_rate, playwave_max_hint):
Expand All @@ -117,8 +118,12 @@ def _analyze_branches(events, delay, sampling_rate, playwave_max_hint):
if ev["event_type"] == "SECTION_START":
handle = ev.get("handle", None)
user_register = ev.get("user_register", None)
match_prng = ev.get("match_prng", False)
if handle is not None or user_register is not None or match_prng:
prng_sample = ev.get("prng_sample")
if (
handle is not None
or user_register is not None
or prng_sample is not None
):
begin = length_to_samples(ev["time"] + delay, sampling_rate)
# Add the command table interval boundaries as cut points
cut_points.add(begin)
Expand All @@ -129,7 +134,7 @@ def _analyze_branches(events, delay, sampling_rate, playwave_max_hint):
begin=begin,
end=None,
data=FeedbackIntervalData(
handle, ev.get("local"), user_register, match_prng
handle, ev.get("local"), user_register, prng_sample
),
)
)
Expand Down Expand Up @@ -186,7 +191,7 @@ def _interval_list(events, states, signal_ids, delay, sub_channel):
signal_id=cur_signal_id,
time=start["time"] + delay,
play_wave_id=start["play_wave_id"],
amplitude=start["amplitude"],
amplitude=start.get("amplitude", 1.0),
index=index,
oscillator_phase=start.get("oscillator_phase"),
oscillator_frequency=start.get("oscillator_frequency"),
Expand Down Expand Up @@ -322,6 +327,17 @@ def _make_interval_tree(
return interval_tree


def _oscillator_phase_increment_points(
events, signals, delay, sampling_rate
) -> set[int]:
virtual_z_gates = _find_frame_changes(events, delay, sampling_rate)
virtual_z_gates = [v for v in virtual_z_gates if v.signal in signals]

frame_change_times = {frame_change.time for frame_change in virtual_z_gates}

return frame_change_times


@dataclass
class _FrameChange:
time: int
Expand Down Expand Up @@ -353,6 +369,22 @@ def _find_frame_changes(events, delay, sampling_rate):
return virtual_z_gates


def _make_virtual_z_gate_event(frame_change: _FrameChange, signal: SignalObj):
return AWGEvent(
type=AWGEventType.CHANGE_OSCILLATOR_PHASE,
start=frame_change.time,
end=frame_change.time,
priority=frame_change.priority,
params={
"signal": frame_change.signal,
"phase": frame_change.phase,
"oscillator": signal.hw_oscillator
if signal.awg.device_type.supports_oscillator_switching
else None,
},
)


def _insert_frame_changes(
interval_events: AWGSampledEventSequence,
events,
Expand All @@ -361,16 +393,16 @@ def _insert_frame_changes(
signals: dict[str, SignalObj],
):
frame_change_events = AWGSampledEventSequence()
frame_changes = _find_frame_changes(events, delay, sampling_rate)
frame_changes = [fc for fc in frame_changes if fc.signal in signals]

intervals = IntervalTree.from_tuples(
(ev.start, ev.end, ev)
for t, evl in interval_events.sequence.items()
for ev in evl
if ev.type == AWGEventType.PLAY_WAVE
)

frame_changes = _find_frame_changes(events, delay, sampling_rate)
frame_changes = [fc for fc in frame_changes if fc.signal in signals]

for frame_change in frame_changes:
if frame_change.signal not in signals.keys():
continue
Expand All @@ -393,20 +425,7 @@ def _insert_frame_changes(
" adding a small delay after the phase increment."
)
frame_change_events.add(
frame_change.time,
AWGEvent(
type=AWGEventType.CHANGE_OSCILLATOR_PHASE,
start=frame_change.time,
end=frame_change.time,
priority=frame_change.priority,
params={
"signal": frame_change.signal,
"phase": frame_change.phase,
"oscillator": signal.hw_oscillator
if signal.awg.device_type.supports_oscillator_switching
else None,
},
),
frame_change.time, _make_virtual_z_gate_event(frame_change, signal)
)
continue

Expand All @@ -418,6 +437,15 @@ def _insert_frame_changes(
signal.hw_oscillator != signature.hw_oscillator
and signature.hw_oscillator is not None
):
# Oscillator conflict!
# We may still get out of this, if the frame change happens at the beginning
# of the interval, by emitting a 0-length command table entry.
if frame_change.time == play_iv.data.start:
frame_change_events.add(
frame_change.time, _make_virtual_z_gate_event(frame_change, signal)
)
continue

raise LabOneQException(
f"Cannot increment oscillator '{signal.hw_oscillator}' of signal"
f" '{signal.id}': the line is occupied by '{signature.hw_oscillator}'."
Expand Down Expand Up @@ -450,6 +478,7 @@ def _oscillator_switch_cut_points(
interval_tree: IntervalTree,
signals: Dict[str, SignalObj],
sample_multiple,
oscillator_phase_increment_times: set[int],
) -> Tuple[AWGSampledEventSequence, Set]:
cut_points = set()

Expand Down Expand Up @@ -513,6 +542,13 @@ def reducer(a, b):

osc_intervals.merge_overlaps(reducer)

for time in oscillator_phase_increment_times:
# Check if the frame change happens after the last oscillator switch - the
# compiler will otherwise assume that the last oscillator switch stays valid
# to the end of the sequence.
if time >= osc_intervals.end():
cut_points.add(ceil(time / sample_multiple) * sample_multiple)

oscillator_switch_events = AWGSampledEventSequence()
for iv in osc_intervals:
osc_switch_event = AWGEvent(
Expand Down Expand Up @@ -682,10 +718,22 @@ def analyze_play_wave_times(
cut_points.add(sequence_end)
cut_points.update(other_events.sequence)

# Oscillator phase increments are determined in a two-step process.
# 1. A preliminary list of all phase increments is collected. These are used to
# determine the cut points for waveforms, as some phase increments require
# oscillator switching.
# 2. Once the waveform intervals have been established (aka compacted), the actual
# phase increments are baked into the intervals.
prelim_oscillator_phase_increments = _oscillator_phase_increment_points(
events, signals, delay, sampling_rate
)

(
oscillator_switch_events,
oscillator_switch_cut_points,
) = _oscillator_switch_cut_points(interval_tree, signals, sample_multiple)
) = _oscillator_switch_cut_points(
interval_tree, signals, sample_multiple, prelim_oscillator_phase_increments
)

cut_points.update(oscillator_switch_cut_points)
oscillator_intervals = _oscillator_intervals(
Expand Down Expand Up @@ -749,7 +797,7 @@ def analyze_play_wave_times(
"handle": interval.data.handle,
"local": interval.data.local,
"user_register": interval.data.user_register,
"match_prng": interval.data.match_prng,
"prng_sample": interval.data.prng_sample,
"signal_id": signal_id,
"section_name": section_name,
},
Expand Down
Loading

0 comments on commit 310364d

Please sign in to comment.