From 0e438997a571523291415ae5a7b94ac9f0e6f2ce Mon Sep 17 00:00:00 2001 From: ltindall Date: Sun, 15 Sep 2024 12:57:56 -0700 Subject: [PATCH] Add class mapping to custom_handler. Save index_to_name.json in mar file. Update train_iq.py to simplify logging structure. --- .gitignore | 2 +- README.md | 4 ++-- custom_handlers/iq_custom_handler.py | 35 ++++++++++++++++++++++------ rfml/export_model.py | 21 +++++++++++++++-- rfml/train_iq.py | 30 ++++++++++++++++++++++-- 5 files changed, 78 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 6c5b72f..f51a867 100644 --- a/.gitignore +++ b/.gitignore @@ -16,5 +16,5 @@ experiment_logs/ lightning_logs/ spec_logs/ - +tensorboard_logs/ rfml-dev/.README.md.swp diff --git a/README.md b/README.md index 1619d60..ed97b77 100644 --- a/README.md +++ b/README.md @@ -175,7 +175,7 @@ I/Q TRAINING COMPLETE Find results in experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32 Total Accuracy: 98.10% -Best Model Checkpoint: lightning_logs/version_5/checkpoints/experiment_logs/experiment_1/iq_checkpoints/checkpoint.ckpt +Best Model Checkpoint: experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32/checkpoints/checkpoint.ckpt ``` ### Convert & Export IQ Models @@ -183,7 +183,7 @@ Best Model Checkpoint: lightning_logs/version_5/checkpoints/experiment_logs/expe Once you have a trained model, you need to convert it into a portable format that can easily be served by TorchServe. To do this, use **export_model.py**: ```bash -python3 rfml/export_model.py --model_name=drone_detect --checkpoint=lightning_logs/version_5/checkpoints/experiment_logs/experiment_1/iq_checkpoints/checkpoint.ckpt +python3 rfml/export_model.py --model_name=drone_detect --checkpoint=experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32/checkpoints/checkpoint.ckpt --index_to_name=experiment_logs/experiment_1/iq_logs/08_08_2024_09_17_32/index_to_name.json ``` This will create a **_torchscript.pt** and **_torchserve.pt** file in the weights folder. diff --git a/custom_handlers/iq_custom_handler.py b/custom_handlers/iq_custom_handler.py index 848a03e..8c143fd 100644 --- a/custom_handlers/iq_custom_handler.py +++ b/custom_handlers/iq_custom_handler.py @@ -3,10 +3,13 @@ import os import time import torch +import logging from collections import defaultdict from ts.torch_handler.base_handler import BaseHandler +from ts.utils.util import load_label_mapping + try: import torch_xla.core.xla_model as xm @@ -39,8 +42,10 @@ def initialize(self, context): :param context: Initial context contains model server system properties. :return: """ + logging.info(f"\n\nStarting iq_custom_handler...\n\n") if torch.cuda.is_available(): self.device = torch.device("cuda") + print(f"{self.handler_name}: using CUDA") elif XLA_AVAILABLE: self.device = xm.xla_device() @@ -56,6 +61,15 @@ def initialize(self, context): self.manifest = context.manifest model_dir = context.system_properties.get("model_dir") + + mapping_file_path = os.path.join(model_dir, "index_to_name.json") + # print(f"{os.listdir('.')=}") + # print(f"{os.listdir(model_dir)=}") + + if not os.path.exists(mapping_file_path): + raise ValueError + self.mapping = load_label_mapping(mapping_file_path) + self.model_pt_path = None if "serializedFile" in self.manifest["model"]: serialized_file = self.manifest["model"]["serializedFile"] @@ -170,13 +184,14 @@ def gate(self, data): self.add_to_avg(avg_pwr) - print("\n=====================================\n") - print("\n=====================================\n") - print(f"\n{data=}\n") + print("\n=====================================") + print("=====================================\n") + print(f"GATE") + print(f"\n{data.shape=}\n") print(f"\n{avg_pwr=}, \n{self.max_db=}, \n{self.avg_db_historical=}\n") print(f"\n{torch.min(torch.abs(data)**2)=}, {torch.max(torch.abs(data)**2)=}\n") - print("\n=====================================\n") - print("\n=====================================\n") + print("\n=====================================") + print("=====================================") if avg_pwr > (self.max_db + self.avg_db_historical) / 2: return False @@ -197,6 +212,7 @@ def preprocess(self, data): data = torch.tensor(np.frombuffer(body, dtype=np.complex64), dtype=torch.cfloat) print("\n=====================================\n") + print(f"PREPROCESS") print(f"\n{data=}\n") print(f"\n{torch.min(torch.abs(data)**2)=}, {torch.max(torch.abs(data)**2)=}\n") avg_pwr = torch.mean(torch.abs(data) ** 2) @@ -209,7 +225,7 @@ def preprocess(self, data): # data should be of size (N, 2, n_samples) data = data.to(self.device) - print("\n=====================================\n") + print("\n=====================================") return data def inference(self, model_input): @@ -228,12 +244,17 @@ def postprocess(self, inference_output): :param inference_output: list of inference output :return: list of predict results """ + print("\n=====================================\n") + print(f"POSTPROCESS") + # print(f"{self.mapping=}") + confidences, class_indexes = torch.max(inference_output.data, 1) results = { - str(class_index): [{"confidence": confidence}] + self.mapping[str(class_index)]: [{"confidence": confidence}] for class_index, confidence in zip( class_indexes.tolist(), confidences.tolist() ) } print(f"\n{inference_output=}\n{results=}\n") + print("\n=====================================") return [results] diff --git a/rfml/export_model.py b/rfml/export_model.py index d133c88..567f8db 100644 --- a/rfml/export_model.py +++ b/rfml/export_model.py @@ -4,6 +4,7 @@ import os import argparse import subprocess +from pathlib import Path def argument_parser(): @@ -27,6 +28,12 @@ def argument_parser(): default="custom_handlers/iq_custom_handler.py", help="Custom handler to use when exporting to MAR. Only used if --mode='export'. (default: %(default)s)", ) + parser.add_argument( + "--index_to_name", + type=str, + required=True, + help="Path of JSON file defining mapping of label index to name.", + ) parser.add_argument( "--export_path", type=str, @@ -75,7 +82,9 @@ def convert_model(model_name, checkpoint): return torchscript_file -def export_model(model_name, torchscript_file, custom_handler, export_path): +def export_model( + model_name, torchscript_file, custom_handler, index_to_name, export_path +): torch_model_archiver_args = [ "torch-model-archiver", @@ -88,12 +97,16 @@ def export_model(model_name, torchscript_file, custom_handler, export_path): torchscript_file, "--handler", custom_handler, + "--extra-files", + index_to_name, "--export-path", export_path, "-r", "custom_handlers/requirements.txt", ] + print(f"Saving Torchserve MAR to {str(Path(export_path, model_name+'.mar'))}") + subprocess.run(torch_model_archiver_args) @@ -105,5 +118,9 @@ def export_model(model_name, torchscript_file, custom_handler, export_path): if args.mode == "export": export_model( - args.model_name, torchscript_file, args.custom_handler, args.export_path + args.model_name, + torchscript_file, + args.custom_handler, + args.index_to_name, + args.export_path, ) diff --git a/rfml/train_iq.py b/rfml/train_iq.py index 59742c6..bc8b2e2 100644 --- a/rfml/train_iq.py +++ b/rfml/train_iq.py @@ -19,6 +19,7 @@ from datetime import datetime import numpy as np import os +import json from pathlib import Path from torchsig.models.iq_models.efficientnet.efficientnet import ( @@ -48,6 +49,7 @@ import os from rfml.sigmf_pytorch_dataset import SigMFDataset from rfml.models import ExampleNetwork, SimpleRealNet +from rfml.export_model import * from torchsig.transforms import ( Compose, @@ -287,17 +289,29 @@ def train_iq( ) # Setup checkpoint callbacks - checkpoint_filename = f"{str(output_dir)}/iq_checkpoints/checkpoint" + checkpoint_filename = f"checkpoints/checkpoint" checkpoint_callback = ModelCheckpoint( + dirpath=logs_dir, filename=checkpoint_filename, save_top_k=True, monitor="val_loss", mode="min", ) + + index_to_name_file = Path( + logs_dir, "index_to_name.json" + ) # f"lightning_logs/{experiment_name}/index_to_name.json" + index_to_name = {i: class_list[i] for i in range(len(class_list))} + index_to_name_object = json.dumps(index_to_name, indent=4) + with open(index_to_name_file, "w") as outfile: + outfile.write(index_to_name_object) + # Create and fit trainer experiment_name = experiment_name if experiment_name else 1 logger = TensorBoardLogger( - save_dir=os.getcwd(), version=experiment_name, name="lightning_logs" + save_dir="tensorboard_logs", + # version=experiment_name, + name=experiment_name, # "lightning_logs" ) trainer = Trainer( max_epochs=epochs, @@ -311,6 +325,7 @@ def train_iq( devices=1, logger=logger, # profiler="simple", + default_root_dir=logs_dir, ) print(f"\nStarting training...") trainer.fit(example_model) @@ -380,6 +395,17 @@ def train_iq( print(f"Total Accuracy: {acc*100:.2f}%") print(f"Best Model Checkpoint: {checkpoint_callback.best_model_path}") + torchscript_file = convert_model( + experiment_name, checkpoint_callback.best_model_path + ) + export_model( + experiment_name, + torchscript_file, + "custom_handlers/iq_custom_handler.py", + index_to_name_file, + "models/", + ) + def visualize_dataset( dataset_path,