Skip to content

Commit

Permalink
Merge pull request #91 from ltindall/class_map
Browse files Browse the repository at this point in the history
Add class mapping to custom_handler. Save index_to_name.json in mar f…
  • Loading branch information
ltindall authored Sep 15, 2024
2 parents f46c50c + 0e43899 commit 48126d9
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
experiment_logs/
lightning_logs/
spec_logs/

tensorboard_logs/
rfml-dev/.README.md.swp
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,15 @@ I/Q TRAINING COMPLETE
Find results in experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32

Total Accuracy: 98.10%
Best Model Checkpoint: lightning_logs/version_5/checkpoints/experiment_logs/experiment_1/iq_checkpoints/checkpoint.ckpt
Best Model Checkpoint: experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32/checkpoints/checkpoint.ckpt
```

### Convert & Export IQ Models

Once you have a trained model, you need to convert it into a portable format that can easily be served by TorchServe. To do this, use **export_model.py**:

```bash
python3 rfml/export_model.py --model_name=drone_detect --checkpoint=lightning_logs/version_5/checkpoints/experiment_logs/experiment_1/iq_checkpoints/checkpoint.ckpt
python3 rfml/export_model.py --model_name=drone_detect --checkpoint=experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32/checkpoints/checkpoint.ckpt --index_to_name=experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32/index_to_name.json
```
This will create a **_torchscript.pt** and **_torchserve.pt** file in the weights folder.

Expand Down
35 changes: 28 additions & 7 deletions custom_handlers/iq_custom_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import os
import time
import torch
import logging

from collections import defaultdict
from ts.torch_handler.base_handler import BaseHandler

from ts.utils.util import load_label_mapping

try:
import torch_xla.core.xla_model as xm

Expand Down Expand Up @@ -39,8 +42,10 @@ def initialize(self, context):
:param context: Initial context contains model server system properties.
:return:
"""
logging.info(f"\n\nStarting iq_custom_handler...\n\n")
if torch.cuda.is_available():
self.device = torch.device("cuda")

print(f"{self.handler_name}: using CUDA")
elif XLA_AVAILABLE:
self.device = xm.xla_device()
Expand All @@ -56,6 +61,15 @@ def initialize(self, context):

self.manifest = context.manifest
model_dir = context.system_properties.get("model_dir")

mapping_file_path = os.path.join(model_dir, "index_to_name.json")
# print(f"{os.listdir('.')=}")
# print(f"{os.listdir(model_dir)=}")

if not os.path.exists(mapping_file_path):
raise ValueError
self.mapping = load_label_mapping(mapping_file_path)

self.model_pt_path = None
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
Expand Down Expand Up @@ -170,13 +184,14 @@ def gate(self, data):

self.add_to_avg(avg_pwr)

print("\n=====================================\n")
print("\n=====================================\n")
print(f"\n{data=}\n")
print("\n=====================================")
print("=====================================\n")
print(f"GATE")
print(f"\n{data.shape=}\n")
print(f"\n{avg_pwr=}, \n{self.max_db=}, \n{self.avg_db_historical=}\n")
print(f"\n{torch.min(torch.abs(data)**2)=}, {torch.max(torch.abs(data)**2)=}\n")
print("\n=====================================\n")
print("\n=====================================\n")
print("\n=====================================")
print("=====================================")

if avg_pwr > (self.max_db + self.avg_db_historical) / 2:
return False
Expand All @@ -197,6 +212,7 @@ def preprocess(self, data):

data = torch.tensor(np.frombuffer(body, dtype=np.complex64), dtype=torch.cfloat)
print("\n=====================================\n")
print(f"PREPROCESS")
print(f"\n{data=}\n")
print(f"\n{torch.min(torch.abs(data)**2)=}, {torch.max(torch.abs(data)**2)=}\n")
avg_pwr = torch.mean(torch.abs(data) ** 2)
Expand All @@ -209,7 +225,7 @@ def preprocess(self, data):
# data should be of size (N, 2, n_samples)

data = data.to(self.device)
print("\n=====================================\n")
print("\n=====================================")
return data

def inference(self, model_input):
Expand All @@ -228,12 +244,17 @@ def postprocess(self, inference_output):
:param inference_output: list of inference output
:return: list of predict results
"""
print("\n=====================================\n")
print(f"POSTPROCESS")
# print(f"{self.mapping=}")

confidences, class_indexes = torch.max(inference_output.data, 1)
results = {
str(class_index): [{"confidence": confidence}]
self.mapping[str(class_index)]: [{"confidence": confidence}]
for class_index, confidence in zip(
class_indexes.tolist(), confidences.tolist()
)
}
print(f"\n{inference_output=}\n{results=}\n")
print("\n=====================================")
return [results]
21 changes: 19 additions & 2 deletions rfml/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import argparse
import subprocess
from pathlib import Path


def argument_parser():
Expand All @@ -27,6 +28,12 @@ def argument_parser():
default="custom_handlers/iq_custom_handler.py",
help="Custom handler to use when exporting to MAR. Only used if --mode='export'. (default: %(default)s)",
)
parser.add_argument(
"--index_to_name",
type=str,
required=True,
help="Path of JSON file defining mapping of label index to name.",
)
parser.add_argument(
"--export_path",
type=str,
Expand Down Expand Up @@ -75,7 +82,9 @@ def convert_model(model_name, checkpoint):
return torchscript_file


def export_model(model_name, torchscript_file, custom_handler, export_path):
def export_model(
model_name, torchscript_file, custom_handler, index_to_name, export_path
):

torch_model_archiver_args = [
"torch-model-archiver",
Expand All @@ -88,12 +97,16 @@ def export_model(model_name, torchscript_file, custom_handler, export_path):
torchscript_file,
"--handler",
custom_handler,
"--extra-files",
index_to_name,
"--export-path",
export_path,
"-r",
"custom_handlers/requirements.txt",
]

print(f"Saving Torchserve MAR to {str(Path(export_path, model_name+'.mar'))}")

subprocess.run(torch_model_archiver_args)


Expand All @@ -105,5 +118,9 @@ def export_model(model_name, torchscript_file, custom_handler, export_path):

if args.mode == "export":
export_model(
args.model_name, torchscript_file, args.custom_handler, args.export_path
args.model_name,
torchscript_file,
args.custom_handler,
args.index_to_name,
args.export_path,
)
30 changes: 28 additions & 2 deletions rfml/train_iq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import datetime
import numpy as np
import os
import json
from pathlib import Path

from torchsig.models.iq_models.efficientnet.efficientnet import (
Expand Down Expand Up @@ -48,6 +49,7 @@
import os
from rfml.sigmf_pytorch_dataset import SigMFDataset
from rfml.models import ExampleNetwork, SimpleRealNet
from rfml.export_model import *

from torchsig.transforms import (
Compose,
Expand Down Expand Up @@ -287,17 +289,29 @@ def train_iq(
)

# Setup checkpoint callbacks
checkpoint_filename = f"{str(output_dir)}/iq_checkpoints/checkpoint"
checkpoint_filename = f"checkpoints/checkpoint"
checkpoint_callback = ModelCheckpoint(
dirpath=logs_dir,
filename=checkpoint_filename,
save_top_k=True,
monitor="val_loss",
mode="min",
)

index_to_name_file = Path(
logs_dir, "index_to_name.json"
) # f"lightning_logs/{experiment_name}/index_to_name.json"
index_to_name = {i: class_list[i] for i in range(len(class_list))}
index_to_name_object = json.dumps(index_to_name, indent=4)
with open(index_to_name_file, "w") as outfile:
outfile.write(index_to_name_object)

# Create and fit trainer
experiment_name = experiment_name if experiment_name else 1
logger = TensorBoardLogger(
save_dir=os.getcwd(), version=experiment_name, name="lightning_logs"
save_dir="tensorboard_logs",
# version=experiment_name,
name=experiment_name, # "lightning_logs"
)
trainer = Trainer(
max_epochs=epochs,
Expand All @@ -311,6 +325,7 @@ def train_iq(
devices=1,
logger=logger,
# profiler="simple",
default_root_dir=logs_dir,
)
print(f"\nStarting training...")
trainer.fit(example_model)
Expand Down Expand Up @@ -380,6 +395,17 @@ def train_iq(
print(f"Total Accuracy: {acc*100:.2f}%")
print(f"Best Model Checkpoint: {checkpoint_callback.best_model_path}")

torchscript_file = convert_model(
experiment_name, checkpoint_callback.best_model_path
)
export_model(
experiment_name,
torchscript_file,
"custom_handlers/iq_custom_handler.py",
index_to_name_file,
"models/",
)


def visualize_dataset(
dataset_path,
Expand Down

0 comments on commit 48126d9

Please sign in to comment.