Skip to content

Commit

Permalink
Merge pull request #88 from IQTLabs/ltindall-label_refactor
Browse files Browse the repository at this point in the history
Ltindall label refactor
  • Loading branch information
ltindall authored Sep 13, 2024
2 parents d6c2eca + caa1641 commit c114e2e
Show file tree
Hide file tree
Showing 25 changed files with 5,169 additions and 633 deletions.
61 changes: 35 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,31 @@ In the labeling scripts, the settings for autolabeling need to be tuned for the

```python
annotation_utils.annotate(
f,
label="mavic3_video", # This is the label that is applied to all of the matching annotations
avg_window_len=256, # The number of samples over which to average signal power
avg_duration=0.25, # The number of seconds, from the start of the recording to use to automatically calculate the SNR threshold, if it is None then all of the samples will be used
debug=False,
set_bandwidth=10000000, # Manually set the bandwidth of the signals in Hz, if this parameter is set, then spectral_energy_threshold is ignored
spectral_energy_threshold=0.95, # Percentage used to determine the upper and lower frequency bounds for an annotation
force_threshold_db=-58, # Used to manually set the threshold used for detecting a signal and creating an annotation. If None, then the automatic threshold calculation will be used instead.
overwrite=False, # If True, any existing annotations in the .sigmf-meta file will be removed
min_bandwidth=16e6, # The minimum bandwidth (in Hz) of a signal to annotate
max_bandwidth=None, # The maximum bandwidth (in Hz) of a signal to annotate
min_annotation_length=10000, # The minimum numbers of samples in length a signal needs to be in order for it to be annotated. This is directly related to the sample rate a signal was captured at and does not take into account bandwidth. So 10000 samples at 20,000,000 samples per second, would mean a minimum transmission length of 0.0005 seconds
# max_annotations=500, # The maximum number of annotations to automatically add
dc_block=True # De-emphasize the DC spike when trying to calculate the frequencies for a signal
)
rfml.data.Data(filename),
avg_window_len=256, # The window size to use when averaging signal power
power_estimate_duration=0.1, # Process the file in chunks of power_estimate_duration seconds
debug_duration=0.25, # If debug==True, then plot debug_duration seconds of data in debug plots
debug=False, # Set True to enable debugging plots
verbose=False, # Set True to enable verbose messages
dry_run=False, # Set True to disable annotations being written to SigMF-Meta file.
bandwidth_estimation=True, # If set to True, will estimate signal bandwidth using Gaussian Mixture Models. If set to a float will estimate signal bandwidth using spectral thresholding.
force_threshold_db=None, # Used to manually set the threshold used for detecting a signal and creating an annotation. If None, then the automatic threshold calculation will be used instead.
overwrite=True, # If True, any existing annotations in the .sigmf-meta file will be removed
max_annotations=None, # If set, limits the number of annotations to add.
dc_block=None, # De-emphasize the DC spike when trying to calculate the frequencies for a signal
time_start_stop=None, # Sets the start/stop time for annotating the recording (must be tuple or list of length 2).
n_components = None, # Sets the number of mixture components to use when calculating signal detection threshold. If not set, then automatically calculated from labels.
n_init=1, # Number of initializations to use in Gaussian Mixture Method. Increasing this number can significantly increase run time.
fft_len=256, # FFT length used in calculating bandwidth
labels = { # The labels dictionary defines the annotations that the script will attempt to find.
"mavic3_video": { # The dictionary keys define the annotation labels. Only a key is necessary.
"bandwidth_limits": (8e6, None), # Optional. Set min/max bandwidth limit for a signal. If None, no min/max limit.
"annotation_length": (10000, None), # Optional. Set min/max annotation length in number of samples. If None, no min/max limit.
"annotation_seconds": (0.0001, 0.0025), # Optional. Set min/max annotation length in seconds. If None, no min/max limit.
"set_bandwidth": (-8.5e6, 9.5e6) # Optional. Ignore bandwidth estimation, set bandwidth manually. Limits are in relation to center frequency.
}
}
)
```

### Tips for Tuning Autolabeling
Expand Down Expand Up @@ -138,7 +148,7 @@ After you have finished labeling your data, the next step is to train a model on

### Configure

This repo provides an automated script for training and evaluating models. To do this, configure the [run_experiments.py](rfml/run_experiments.py) file to point to the data you want to use and set the training parameters:
This repo provides an automated script for training and evaluating models. To do this, configure the [mixed_experiments.py](experiments/mixed_experiments.py) file, or create your own, to point to the data you want to use and set the training parameters:

```python
"experiment_0": { # A name to refer to the experiment
Expand All @@ -150,10 +160,10 @@ This repo provides an automated script for training and evaluating models. To do
}
```

Once you have the **run_experiments.py** file configured, run it:
Once you have the **mixed_experiments.py** file configured, run it:

```bash
python3 run_experiments.py
python3 experiments/mixed_experiments.py
```

Once the training has completed, it will print out the logs location, model accuracy, and the location of the best checkpoint:
Expand All @@ -170,18 +180,15 @@ Best Model Checkpoint: lightning_logs/version_5/checkpoints/experiment_logs/expe

### Convert & Export IQ Models

Once you have a trained model, you need to convert it into a portable format that can easily be served by TorchServe. To do this, use **convert_model.py**:
Once you have a trained model, you need to convert it into a portable format that can easily be served by TorchServe. To do this, use **export_model.py**:

```bash
python3 convert_model.py --model_name=drone_detect --checkpoint=lightning_logs/version_5/checkpoints/experiment_logs/experiment_1/iq_checkpoints/checkpoint.ckpt
python3 rfml/export_model.py --model_name=drone_detect --checkpoint=lightning_logs/version_5/checkpoints/experiment_logs/experiment_1/iq_checkpoints/checkpoint.ckpt
```
This will export a **_torchscript.pt** file.
This will create **_torchscript.pt** and **_torchserve.pt** files in the weights folder.

```bash
torch-model-archiver --force --model-name drone_detect --version 1.0 --serialized-file weights/drone_detect_torchscript.pt --handler custom_handlers/iq_custom_handler.py --export-path models/ -r custom_handlers/requirements.txt
```
A **.mar** file will also be created in the [models/](./models/) folder. [GamutRF](https://github.com/IQTLabs/gamutRF) can run this model and use it to classify signals.

This will generate a **.mar** file in the [models/](./models/) folder. [GamutRF](https://github.com/IQTLabs/gamutRF) can run this model and use it to classify signals.

## Files

Expand All @@ -194,9 +201,11 @@ This will generate a **.mar** file in the [models/](./models/) folder. [GamutRF]

[experiment.py](rfml/experiment.py) - Class to manage experiments

[export_model.py](rfml/export_model.py) - Convert and export model checkpoints to Torchscript/Torchserve/MAR format.

[models.py](rfml/models.py) - Class for I/Q models (based on TorchSig)

[run_experiments.py](rfml/run_experiments.py) - Experiment configurations and run script
[experiments/](experiments/) - Experiment configurations and run scripts

[sigmf_pytorch_dataset.py](rfml/sigmf_pytorch_dataset.py) - PyTorch style dataset class for SigMF data (based on TorchSig)

Expand Down
139 changes: 139 additions & 0 deletions experiments/dji_mini2_wifi_experiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from rfml.experiment import *

# Ensure that the data directories contain sigmf-meta files with annotations.
# Annotations can be generated using the scripts in the label_scripts directory
# or notebooks/Label_WiFi.ipynb and notebooks/Label_DJI.ipynb.
#
# Every experiment below trains on the NZ WiFi recordings plus one or two
# Mini 2 capture locations and validates on the NZ WiFi recordings plus a
# Mini 2 capture location that was not used for training.

# Training settings shared by every experiment in this file.
spec_epochs = 0
iq_epochs = 10
iq_only_start_of_burst = False
iq_num_samples = 4000
iq_early_stop = 3
iq_train_limit = 0.01
iq_val_limit = 0.1

# Dataset locations referenced by the experiments.
NZ_WIFI_DIR = "/data/s3_gamutrf/gamutrf-nz-wifi"
ARL_MINI2_DIR = "/data/s3_gamutrf/gamutrf-arl/01_30_23/mini2"
PDX_MINI2_DIR = "/data/s3_gamutrf/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/iq_recordings"
LEESBURG_MINI2_DIR = "/data/s3_gamutrf/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/iq_recordings"


def _make_experiment(train_sites, val_site):
    """Build one experiment configuration.

    Args:
        train_sites: Mini 2 dataset directories to train on (in order).
        val_site: Mini 2 dataset directory to validate on.

    Returns:
        A new dict of experiment settings. The NZ WiFi directory is always
        the first entry of both ``train_dir`` and ``val_dir``. Fresh list
        objects are created on every call so experiments never share
        mutable state.
    """
    return {
        "class_list": ["mini2_video", "mini2_telem", "wifi"],
        "train_dir": [NZ_WIFI_DIR, *train_sites],
        "val_dir": [NZ_WIFI_DIR, val_site],
        "iq_epochs": iq_epochs,
        "spec_epochs": spec_epochs,
        "iq_only_start_of_burst": iq_only_start_of_burst,
        # Previously iq_num_samples was defined above but never passed on,
        # so the value set in this file had no effect.
        "iq_num_samples": iq_num_samples,
        "iq_early_stop": iq_early_stop,
        "iq_train_limit": iq_train_limit,
        "iq_val_limit": iq_val_limit,
        "notes": "",
    }


experiments = {
    "experiment_nz_wifi_arl_mini2_pdx_mini2_to_leesburg_mini2": _make_experiment(
        [ARL_MINI2_DIR, PDX_MINI2_DIR], LEESBURG_MINI2_DIR
    ),
    "experiment_nz_wifi_arl_mini2_to_leesburg_mini2": _make_experiment(
        [ARL_MINI2_DIR], LEESBURG_MINI2_DIR
    ),
    "experiment_nz_wifi_pdx_mini2_to_leesburg_mini2": _make_experiment(
        [PDX_MINI2_DIR], LEESBURG_MINI2_DIR
    ),
    "experiment_nz_wifi_arl_mini2_to_pdx_mini2": _make_experiment(
        [ARL_MINI2_DIR], PDX_MINI2_DIR
    ),
    "experiment_nz_wifi_leesburg_mini2_to_pdx_mini2": _make_experiment(
        [LEESBURG_MINI2_DIR], PDX_MINI2_DIR
    ),
    "experiment_nz_wifi_leesburg_mini2_pdx_mini2_to_arl_mini2": _make_experiment(
        [LEESBURG_MINI2_DIR, PDX_MINI2_DIR], ARL_MINI2_DIR
    ),
}


if __name__ == "__main__":

    # Uncomment the experiments that should be run.
    experiments_to_run = [
        # "experiment_nz_wifi_arl_mini2_pdx_mini2_to_leesburg_mini2",
        # "experiment_nz_wifi_arl_mini2_to_leesburg_mini2",
        # "experiment_nz_wifi_pdx_mini2_to_leesburg_mini2",
        # "experiment_nz_wifi_arl_mini2_to_pdx_mini2",
        # "experiment_nz_wifi_leesburg_mini2_to_pdx_mini2",
        "experiment_nz_wifi_leesburg_mini2_pdx_mini2_to_arl_mini2",
    ]

    # `train` is provided by the wildcard import from rfml.experiment.
    train({name: experiments[name] for name in experiments_to_run})
51 changes: 1 addition & 50 deletions rfml/run_experiments.py → experiments/mixed_experiments.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from pathlib import Path


import torch

torch.set_float32_matmul_precision("high")


from rfml.experiment import *
from rfml.train_iq import *
from rfml.train_spec import *


# Ensure that data directories have sigmf-meta files with annotations
# Annotations can be generated using scripts in label_scripts directory or notebooks/Label_WiFi.ipynb and notebooks/Label_DJI.ipynb
Expand Down Expand Up @@ -277,47 +271,4 @@
"experiment_siggen",
]

for experiment_name in experiments_to_run:
print(f"Running {experiment_name}")
try:
exp = Experiment(
experiment_name=experiment_name, **experiments[experiment_name]
)

logs_timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")

if exp.iq_epochs > 0:
train_iq(
train_dataset_path=exp.train_dir,
val_dataset_path=exp.val_dir,
num_iq_samples=exp.iq_num_samples,
only_use_start_of_burst=exp.iq_only_start_of_burst,
epochs=exp.iq_epochs,
batch_size=exp.iq_batch_size,
class_list=exp.class_list,
output_dir=Path("experiment_logs", exp.experiment_name),
logs_dir=Path("iq_logs", logs_timestamp),
)
else:
print("Skipping IQ training")

if exp.spec_epochs > 0:
train_spec(
train_dataset_path=exp.train_dir,
val_dataset_path=exp.val_dir,
n_fft=exp.spec_n_fft,
time_dim=exp.spec_time_dim,
epochs=exp.spec_epochs,
batch_size=exp.spec_batch_size,
class_list=exp.class_list,
yolo_augment=exp.spec_yolo_augment,
skip_export=exp.spec_skip_export,
force_yolo_label_larger=exp.spec_force_yolo_label_larger,
output_dir=Path("experiment_logs", exp.experiment_name),
logs_dir=Path("spec_logs", logs_timestamp),
)
else:
print("Skipping spectrogram training")

except Exception as error:
print(f"Error: {error}")
train({name: experiments[name] for name in experiments_to_run})
37 changes: 37 additions & 0 deletions experiments/siggen_experiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

torch.set_float32_matmul_precision("medium")
from rfml.experiment import *

#
# Example workflow (home-directory paths are specific to the original author's machine):
# python experiments/siggen_experiments.py
# python rfml/export_model.py --model_name siggen_model --checkpoint /home/ltindall/iqt/rfml/lightning_logs/siggen_experiment/checkpoints/experiment_logs/siggen_experiment/iq_checkpoints/checkpoint-v3.ckpt
# torch-model-archiver --force --model-name siggen_model --version 1.0 --serialized-file rfml/weights/siggen_model_torchscript.pt --handler custom_handlers/iq_custom_handler.py --export-path models/ -r custom_handlers/requirements.txt
# cp models/siggen_model.mar ~/iqt/gamutrf-deploy/docker_rundir/model_store/
# sudo chmod -R 777 /home/ltindall/iqt/gamutrf-deploy/docker_rundir/
#
# NOTE(review): the README states that export_model.py also creates the .mar
# file, so the torch-model-archiver step may no longer be needed -- confirm.
#

# Signal-generator recordings used by the experiment below.
SIGGEN_RECORDINGS = (
    "/data/siggen/fm.sigmf-meta",
    "/data/siggen/am.sigmf-meta",
)

experiments = {
    "siggen_experiment": {
        "class_list": ["am", "fm"],
        # NOTE(review): val_dir lists the same recordings as train_dir; how
        # samples are divided between training and validation is not visible
        # from this file -- confirm this is intended.
        "train_dir": list(SIGGEN_RECORDINGS),
        "val_dir": list(SIGGEN_RECORDINGS),
        "iq_epochs": 10,
        "iq_train_limit": 0.5,
        "iq_only_start_of_burst": False,
        "iq_num_samples": 1024,
        "spec_epochs": 0,
    }
}


if __name__ == "__main__":

    # `train` is provided by the wildcard import from rfml.experiment.
    train(experiments)
Loading

0 comments on commit c114e2e

Please sign in to comment.