Skip to content

Commit

Permalink
Merge pull request #1385 from anarkiwi/infb
Browse files Browse the repository at this point in the history
refactor to pipeline blocks
  • Loading branch information
anarkiwi authored Aug 19, 2024
2 parents dad1889 + 1c55117 commit 2cc2318
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 112 deletions.
218 changes: 122 additions & 96 deletions gamutrf/grscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def __init__(
logging.warn(">1s dwell time in stare mode, updates will be slow!")
peak_fft_range = min(peak_fft_range, tune_step_fft)

fft_dir = ""
if write_fft_points:
fft_dir = sample_dir

self.sources, cmd_port, self.workaround_start_hook = get_source(
sdr,
samp_rate,
Expand All @@ -175,7 +179,14 @@ def __init__(
dc_ettus_auto_offset=dc_ettus_auto_offset,
)

fft_batch_size, self.fft_blocks = self.get_fft_blocks(
(
fft_batch_size,
self.retune_pre_fft,
self.retune_fft,
self.db_block,
self.sample_block,
self.pipeline_blocks,
) = self.get_pipeline_blocks(
samp_rate,
tune_jitter_hz,
vkfft,
Expand All @@ -194,15 +205,37 @@ def __init__(
dc_block_len,
dc_block_long,
correct_iq,
scaling,
db_clamp_floor,
db_clamp_ceil,
fft_dir,
write_samples,
bucket_range,
description,
rotate_secs,
peak_fft_range,
)
self.fft_blocks = self.fft_blocks + self.get_db_blocks(nfft, samp_rate, scaling)
self.last_db_block = self.fft_blocks[-1]
fft_dir = ""
fft_zmq_block_addr = f"tcp://{fft_zmq_addr}:{fft_zmq_port}"
self.pduzmq_block = pduzmq(fft_zmq_block_addr)
logging.info("serving FFT on %s", fft_zmq_block_addr)

if iq_zmq_port:
iq_zmq_block_addr = f"tcp://{iq_zmq_addr}:{iq_zmq_port}"
logging.info("serving I/Q samples and tags on %s", iq_zmq_block_addr)
iq_zmq_block = zeromq.pub_sink(
gr.sizeof_gr_complex,
fft_batch_size * nfft,
iq_zmq_block_addr,
100,
True,
65536,
"",
)
self.connect((self.sample_block, 0), (iq_zmq_block, 0))

self.samples_blocks = []
self.write_samples_block = None
if write_samples:
if write_fft_points:
fft_dir = sample_dir
Path(sample_dir).mkdir(parents=True, exist_ok=True)
samples_vlen = fft_batch_size * nfft
self.samples_blocks.extend(
Expand Down Expand Up @@ -233,49 +266,6 @@ def __init__(
)
self.write_samples_block = self.samples_blocks[-1]

retune_fft = self.iqtlabs.retune_fft(
tag="rx_freq",
nfft=nfft,
samp_rate=int(samp_rate),
tune_jitter_hz=int(tune_jitter_hz),
freq_start=int(freq_start),
freq_end=int(freq_end),
tune_step_hz=tune_step_hz,
tune_step_fft=tune_step_fft,
skip_tune_step_fft=skip_tune_step,
fft_min=db_clamp_floor,
fft_max=db_clamp_ceil,
sdir=fft_dir,
write_step_fft=write_samples,
bucket_range=bucket_range,
tuning_ranges=tuning_ranges,
description=description,
rotate_secs=rotate_secs,
pre_fft=pretune,
tag_now=self.tag_now,
low_power_hold_down=(not pretune and low_power_hold_down),
slew_rx_time=slew_rx_time,
peak_fft_range=peak_fft_range,
)
self.fft_blocks.append(retune_fft)
fft_zmq_block_addr = f"tcp://{fft_zmq_addr}:{fft_zmq_port}"
self.pduzmq_block = pduzmq(fft_zmq_block_addr)
logging.info("serving FFT on %s", fft_zmq_block_addr)

if iq_zmq_port:
iq_zmq_block_addr = f"tcp://{iq_zmq_addr}:{iq_zmq_port}"
logging.info("serving I/Q samples and tags on %s", iq_zmq_block_addr)
iq_zmq_block = zeromq.pub_sink(
gr.sizeof_gr_complex,
fft_batch_size * nfft,
iq_zmq_block_addr,
100,
True,
65536,
"",
)
self.connect((self.retune_pre_fft, 0), (iq_zmq_block, 0))

self.inference_blocks = []
self.inference_output_block = None
self.image_inference_block = None
Expand All @@ -284,20 +274,17 @@ def __init__(
if inference_output_dir:
Path(inference_output_dir).mkdir(parents=True, exist_ok=True)

if inference_text_color:
wc = webcolors.name_to_rgb(inference_text_color, "css3")
inference_text_color = ",".join(
[str(c) for c in [wc.blue, wc.green, wc.red]]
)

if (inference_model_server and inference_model_name) or inference_output_dir:
x = 640
y = 640
if inference_text_color:
wc = webcolors.name_to_rgb(inference_text_color, "css3")
inference_text_color = ",".join(
[str(c) for c in [wc.blue, wc.green, wc.red]]
)
self.image_inference_block = self.iqtlabs.image_inference(
tag="rx_freq",
vlen=nfft,
x=x,
y=y,
x=640,
y=640,
image_dir=inference_output_dir,
convert_alpha=255,
norm_alpha=0,
Expand Down Expand Up @@ -343,7 +330,6 @@ def __init__(
)

# TODO: provide new block that receives JSON-over-PMT and outputs to MQTT/zmq.
retune_fft_output_block = None
if self.inference_blocks:
inference_zmq_addr = f"tcp://{inference_addr}:{inference_port}"
self.inference_output_block = inferenceoutput(
Expand All @@ -359,48 +345,46 @@ def __init__(
inference_output_dir,
)
if self.iq_inference_block:
iq_inference_blocks = [self.iq_inference_block]
if iq_inference_squelch_db is not None:
squelch_blocks = self.wrap_batch(
[
analog.pwr_squelch_cc(
iq_inference_squelch_db,
iq_inference_squelch_alpha,
0,
False,
)
],
fft_batch_size,
nfft,
) + [self.iq_inference_block]
self.connect_blocks(self.retune_pre_fft, squelch_blocks)
else:
self.connect((self.retune_pre_fft, 0), (self.iq_inference_block, 0))
self.connect((self.last_db_block, 0), (self.iq_inference_block, 1))
self.iq_inference_block = (
self.wrap_batch(
[
analog.pwr_squelch_cc(
iq_inference_squelch_db,
iq_inference_squelch_alpha,
0,
False,
)
],
fft_batch_size,
nfft,
)
+ iq_inference_blocks
)
self.connect_blocks(self.sample_block, iq_inference_blocks)
self.connect((self.db_block, 0), (self.iq_inference_block, 1))
if self.image_inference_block:
if stare:
self.connect(
(self.last_db_block, 0), (self.image_inference_block, 0)
)
self.connect((self.db_block, 0), (self.image_inference_block, 0))
else:
retune_fft_output_block = self.image_inference_block
# need to pass samples through retune_fft if using image inference
self.connect((self.retune_fft, 0), (self.image_inference_block, 0))
for block in self.inference_blocks:
self.msg_connect(
(block, "inference"), (self.inference_output_block, "inference")
)

if retune_fft_output_block:
self.connect((retune_fft, 0), (retune_fft_output_block, 0))

if pretune:
self.msg_connect((self.retune_pre_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((self.retune_pre_fft, "tune"), (retune_fft, "cmd"))
self.msg_connect((self.retune_pre_fft, "tune"), (self.retune_fft, "cmd"))
else:
self.msg_connect((retune_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((retune_fft, "json"), (self.pduzmq_block, "json"))
self.connect_blocks(self.sources[0], self.sources[1:])
self.msg_connect((self.retune_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((self.retune_fft, "json"), (self.pduzmq_block, "json"))

self.connect_blocks(self.sources[-1], self.fft_blocks)
self.connect_blocks(self.retune_pre_fft, self.samples_blocks)
self.connect_blocks(self.sources[0], self.sources[1:])
self.connect_blocks(self.sources[-1], self.pipeline_blocks)
self.connect_blocks(self.sample_block, self.samples_blocks)

def connect_blocks(self, source, other_blocks, last_block_port=0):
last_block = source
Expand Down Expand Up @@ -535,7 +519,7 @@ def get_dc_blocks(
)
return []

def get_fft_blocks(
def get_pipeline_blocks(
self,
samp_rate,
tune_jitter_hz,
Expand All @@ -555,14 +539,23 @@ def get_fft_blocks(
dc_block_len,
dc_block_long,
correct_iq,
scaling,
db_clamp_floor,
db_clamp_ceil,
fft_dir,
write_samples,
bucket_range,
description,
rotate_secs,
peak_fft_range,
):
fft_batch_size, fft_blocks = self.get_offload_fft_blocks(
vkfft,
fft_batch_size,
nfft,
fft_processor_affinity,
)
self.retune_pre_fft = self.get_pretune_block(
retune_pre_fft = self.get_pretune_block(
fft_batch_size,
nfft,
samp_rate,
Expand All @@ -577,13 +570,46 @@ def get_fft_blocks(
low_power_hold_down,
slew_rx_time,
)
retune_fft = self.iqtlabs.retune_fft(
tag="rx_freq",
nfft=nfft,
samp_rate=int(samp_rate),
tune_jitter_hz=int(tune_jitter_hz),
freq_start=int(freq_start),
freq_end=int(freq_end),
tune_step_hz=tune_step_hz,
tune_step_fft=tune_step_fft,
skip_tune_step_fft=skip_tune_step,
fft_min=db_clamp_floor,
fft_max=db_clamp_ceil,
sdir=fft_dir,
write_step_fft=write_samples,
bucket_range=bucket_range,
tuning_ranges=tuning_ranges,
description=description,
rotate_secs=rotate_secs,
pre_fft=pretune,
tag_now=self.tag_now,
low_power_hold_down=(not pretune and low_power_hold_down),
slew_rx_time=slew_rx_time,
peak_fft_range=peak_fft_range,
)
sample_blocks = [retune_pre_fft] + self.get_dc_blocks(
correct_iq, dc_block_len, dc_block_long, fft_batch_size, nfft
)
pipeline_blocks = (
sample_blocks
+ fft_blocks
+ self.get_db_blocks(nfft, samp_rate, scaling)
+ [retune_fft]
)
return (
fft_batch_size,
[self.retune_pre_fft]
+ self.get_dc_blocks(
correct_iq, dc_block_len, dc_block_long, fft_batch_size, nfft
)
+ fft_blocks,
retune_pre_fft,
retune_fft,
pipeline_blocks[-1],
sample_blocks[-1],
pipeline_blocks,
)

def start(self):
Expand Down
36 changes: 20 additions & 16 deletions tests/test_grscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,26 @@ def run_grscan_smoke(self, pretune, wavelearner, write_samples, test_file):
]
)
sdr = "file:" + sdr_file
tb = grscan(
freq_start=freq_start,
freq_end=freq_end,
sdr=sdr,
samp_rate=samp_rate,
tune_step_fft=512,
write_samples=write_samples,
sample_dir=tempdir,
iqtlabs=iqtlabs,
wavelearner=wavelearner,
rotate_secs=900,
db_clamp_floor=-1e6,
pretune=pretune,
fft_batch_size=4,
inference_output_dir=str(tempdir),
)
try:
tb = grscan(
freq_start=freq_start,
freq_end=freq_end,
sdr=sdr,
samp_rate=samp_rate,
tune_step_fft=512,
write_samples=write_samples,
sample_dir=tempdir,
iqtlabs=iqtlabs,
wavelearner=wavelearner,
rotate_secs=900,
db_clamp_floor=-1e6,
pretune=pretune,
fft_batch_size=4,
inference_output_dir=str(tempdir),
)
except Exception as e:
print(e)
raise
tb.start()
time.sleep(3)
tb.stop()
Expand Down

0 comments on commit 2cc2318

Please sign in to comment.