export update
a.korepanov committed Mar 13, 2024
1 parent 9242cf5 commit 5c015ea
Showing 4 changed files with 91 additions and 64 deletions.
2 changes: 1 addition & 1 deletion torchDF/README.md
@@ -34,7 +34,7 @@ poetry run python torch_df_streaming_minimal.py --audio-path examples/A1CIM28ZUC

To convert model to onnx and run tests:
```
-poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort --simplify --profiling
+poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort --simplify --profiling --minimal
```

TODO:
107 changes: 50 additions & 57 deletions torchDF/model_onnx_export.py
@@ -10,44 +10,14 @@
import torch.utils.benchmark as benchmark

from torch_df_streaming_minimal import TorchDFMinimalPipeline
+from torch_df_streaming import TorchDFPipeline
from typing import Dict, Iterable
from torch.onnx._internal import jit_utils
from loguru import logger

torch.manual_seed(0)

-FRAME_SIZE = 480
OPSET_VERSION = 17
-INPUT_NAMES = [
-    "input_frame",
-    "erb_norm_state",
-    "band_unit_norm_state",
-    "analysis_mem",
-    "synthesis_mem",
-    "rolling_erb_buf",
-    "rolling_feat_spec_buf",
-    "rolling_c0_buf",
-    "rolling_spec_buf_x",
-    "rolling_spec_buf_y",
-    "enc_hidden",
-    "erb_dec_hidden",
-    "df_dec_hidden",
-]
-OUTPUT_NAMES = [
-    "enhanced_audio_frame",
-    "new_erb_norm_state",
-    "new_band_unit_norm_state",
-    "new_analysis_mem",
-    "new_synthesis_mem",
-    "new_rolling_erb_buf",
-    "new_rolling_feat_spec_buf",
-    "new_rolling_c0_buf",
-    "new_rolling_spec_buf_x",
-    "new_rolling_spec_buf_y",
-    "new_enc_hidden",
-    "new_erb_dec_hidden",
-    "new_df_dec_hidden",
-]


def onnx_simplify(
@@ -78,7 +48,9 @@ def onnx_simplify(
return path


-def test_onnx_model(torch_model, ort_session, states):
+def test_onnx_model(
+    torch_model, ort_session, states, frame_size, input_names, output_names
+):
"""
Simple test that everything converted correctly
@@ -91,32 +63,29 @@ def test_onnx_model(torch_model, ort_session, states):
states_onnx = copy.deepcopy(states)

for i in range(30):
-        input_frame = torch.randn(FRAME_SIZE)
+        input_frame = torch.randn(frame_size)

# torch
output_torch = torch_model(input_frame, *states_torch)

# onnx
output_onnx = ort_session.run(
-            OUTPUT_NAMES,
-            generate_onnx_features([input_frame, *states_onnx]),
+            output_names,
+            generate_onnx_features([input_frame, *states_onnx], input_names),
)

-        for x, y, name in zip(output_torch, output_onnx, OUTPUT_NAMES):
+        for x, y, name in zip(output_torch, output_onnx, output_names):
y_tensor = torch.from_numpy(y)
assert torch.allclose(
x, y_tensor, atol=1e-2
), f"out {name} - {i}, {x.flatten()[-5:]}, {y_tensor.flatten()[-5:]}"


-def generate_onnx_features(input_features):
-    return {x: y.detach().cpu().numpy() for x, y in zip(INPUT_NAMES, input_features)}
+def generate_onnx_features(input_features, input_names):
+    return {x: y.detach().cpu().numpy() for x, y in zip(input_names, input_features)}


-def perform_benchmark(
-    ort_session,
-    input_features: Dict[str, np.ndarray],
-):
+def perform_benchmark(ort_session, input_features: Dict[str, np.ndarray], output_names):
"""
Benchmark ONNX model performance
@@ -127,7 +96,7 @@ def perform_benchmark(

def run_onnx():
output = ort_session.run(
-            OUTPUT_NAMES,
+            output_names,
input_features,
)

@@ -141,16 +110,18 @@ def run_onnx():
)


-def infer_onnx_model(streaming_pipeline, ort_session, inference_path):
+def infer_onnx_model(
+    streaming_pipeline, ort_session, inference_path, input_names, output_names
+):
"""
    Run inference on the ONNX model through TorchDFPipeline
"""
del streaming_pipeline.torch_streaming_model
streaming_pipeline.torch_streaming_model = lambda *features: (
torch.from_numpy(x)
for x in ort_session.run(
-            OUTPUT_NAMES,
-            generate_onnx_features(list(features)),
+            output_names,
+            generate_onnx_features(list(features), input_names),
)
)

@@ -196,11 +167,19 @@ def custom_identity(g: jit_utils.GraphContext, X):


def main(args):
-    streaming_pipeline = TorchDFMinimalPipeline(device="cpu")
+    if args.minimal:
+        streaming_pipeline = TorchDFMinimalPipeline(device="cpu")
+    else:
+        streaming_pipeline = TorchDFPipeline(device="cpu")
+
+    frame_size = streaming_pipeline.hop_size
+    input_names = streaming_pipeline.input_names
+    output_names = streaming_pipeline.output_names

torch_df = streaming_pipeline.torch_streaming_model
states = streaming_pipeline.states

-    input_frame = torch.rand(FRAME_SIZE)
+    input_frame = torch.rand(frame_size)
input_features = (input_frame, *states)
torch_df(*input_features) # check model

@@ -223,17 +202,17 @@ def main(args):
input_features,
args.output_path,
verbose=False,
-        input_names=INPUT_NAMES,
-        output_names=OUTPUT_NAMES,
+        input_names=input_names,
+        output_names=output_names,
opset_version=OPSET_VERSION,
)
logger.info(f"Model exported to {args.output_path}!")

-    input_features_onnx = generate_onnx_features(input_features)
+    input_features_onnx = generate_onnx_features(input_features, input_names)
input_shapes_dict = {x: y.shape for x, y in input_features_onnx.items()}

-    # Simplify not working!
-    if args.simplify:
+    # Simplify does not work for the non-minimal export!
+    if args.simplify and args.minimal:
# raise NotImplementedError("Simplify not working for flatten states!")
onnx_simplify(args.output_path, input_features_onnx, input_shapes_dict)
logger.info(f"Model simplified! {args.output_path}")
@@ -274,7 +253,7 @@ def main(args):

for _ in range(3):
onnx_outputs = ort_session.run(
-            OUTPUT_NAMES,
+            output_names,
input_features_onnx,
)

@@ -288,15 +267,28 @@ def main(args):

if args.test:
logger.info("Testing...")
-        test_onnx_model(torch_df, ort_session, input_features[1:])
+        test_onnx_model(
+            torch_df,
+            ort_session,
+            input_features[1:],
+            frame_size,
+            input_names,
+            output_names,
+        )
logger.info("Tests passed!")

if args.performance:
logger.info("Performanse check...")
perform_benchmark(ort_session, input_features_onnx)
perform_benchmark(ort_session, input_features_onnx, output_names)

if args.inference_path:
-        infer_onnx_model(streaming_pipeline, ort_session, args.inference_path)
+        infer_onnx_model(
+            streaming_pipeline,
+            ort_session,
+            args.inference_path,
+            input_names,
+            output_names,
+        )
logger.info(f"Audio from {args.inference_path} enhanced!")


@@ -318,4 +310,5 @@ def main(args):
parser.add_argument("--inference-path", type=str, help="Run inference on example")
parser.add_argument("--ort", action="store_true", help="Save to ort format")
parser.add_argument("--profiling", action="store_true", help="Run ONNX profiler")
parser.add_argument("--minimal", action="store_true", help="Export minimal version")
main(parser.parse_args())
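Because the input and output names now come from the pipeline object rather than from module-level constants, the exported graph can be driven straight from onnxruntime regardless of which pipeline was exported. Below is a rough sketch of that loop for the default (non-minimal) export; the file name `denoiser.onnx` is a placeholder, and the state handling simply mirrors `infer_onnx_model` above and the updated `TorchDFPipeline.forward` in `torch_df_streaming.py` below.

```python
# Rough sketch, not part of the commit: frame-by-frame ONNX inference for the
# non-minimal export. "denoiser.onnx" is a hypothetical output path.
import onnxruntime as ort
import torch

from torch_df_streaming import TorchDFPipeline

pipeline = TorchDFPipeline(device="cpu")
session = ort.InferenceSession("denoiser.onnx", providers=["CPUExecutionProvider"])

states = pipeline.states  # [packed state tensor, atten_lim_db]
out_frames = []
for _ in range(10):
    frame = torch.randn(pipeline.hop_size)  # stand-in for one hop of real audio
    feed = {
        name: tensor.detach().cpu().numpy()
        for name, tensor in zip(pipeline.input_names, [frame, *states])
    }
    enhanced, new_states, _ = session.run(pipeline.output_names, feed)
    out_frames.append(torch.from_numpy(enhanced))
    # Re-pack the states the same way TorchDFPipeline.forward now does.
    states = [torch.from_numpy(new_states), pipeline.atten_lim_db]

enhanced_audio = torch.cat(out_frames)
```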
16 changes: 10 additions & 6 deletions torchDF/torch_df_streaming.py
@@ -793,11 +793,14 @@ def __init__(
sr=self.sample_rate,
)
self.torch_streaming_model = self.torch_streaming_model.to(device)
-        self.states = torch.zeros(
-            self.torch_streaming_model.states_full_len, device=device
-        )
-
        self.atten_lim_db = torch.tensor(atten_lim_db, device=device)
+        self.states = [
+            torch.zeros(self.torch_streaming_model.states_full_len, device=device),
+            self.atten_lim_db,
+        ]
+
+        self.input_names = ["input_frame", "states", "atten_lim_db"]
+        self.output_names = ["enhanced_audio_frame", "new_states", "lsnr"]

def forward(self, input_audio: Tensor, sample_rate: int) -> Tensor:
"""
@@ -835,9 +838,10 @@ def forward(self, input_audio: Tensor, sample_rate: int) -> Tensor:
output_frames = []

for input_frame in chunked_audio:
-            (enhanced_audio_frame, self.states, lsnr) = self.torch_streaming_model(
-                input_frame, self.states, self.atten_lim_db
+            enhanced_audio_frame, states, _ = self.torch_streaming_model(
+                input_frame, *self.states
            )
+            self.states = [states, self.atten_lim_db]

output_frames.append(enhanced_audio_frame)

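From the caller's side the packed-state change is invisible: `forward` still takes a waveform and a sample rate, and the `[states, atten_lim_db]` list is re-packed internally after every frame. A usage sketch, with the 48 kHz rate and the `[1, num_samples]` input shape assumed rather than taken from this diff:

```python
# Usage sketch under the assumptions stated above; state re-packing happens
# entirely inside forward(), so callers only see audio in, audio out.
import torch

from torch_df_streaming import TorchDFPipeline

pipeline = TorchDFPipeline(device="cpu")
noisy = torch.randn(1, 48000)      # assumed: one second of mono audio at 48 kHz
enhanced = pipeline(noisy, 48000)  # per-frame streaming plus state re-packing inside
print(enhanced.shape)
```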
30 changes: 30 additions & 0 deletions torchDF/torch_df_streaming_minimal.py
@@ -1150,6 +1150,36 @@ def __init__(
erb_dec_hidden,
df_dec_hidden,
]
+        self.input_names = [
+            "input_frame",
+            "erb_norm_state",
+            "band_unit_norm_state",
+            "analysis_mem",
+            "synthesis_mem",
+            "rolling_erb_buf",
+            "rolling_feat_spec_buf",
+            "rolling_c0_buf",
+            "rolling_spec_buf_x",
+            "rolling_spec_buf_y",
+            "enc_hidden",
+            "erb_dec_hidden",
+            "df_dec_hidden",
+        ]
+        self.output_names = [
+            "enhanced_audio_frame",
+            "new_erb_norm_state",
+            "new_band_unit_norm_state",
+            "new_analysis_mem",
+            "new_synthesis_mem",
+            "new_rolling_erb_buf",
+            "new_rolling_feat_spec_buf",
+            "new_rolling_c0_buf",
+            "new_rolling_spec_buf_x",
+            "new_rolling_spec_buf_y",
+            "new_enc_hidden",
+            "new_erb_dec_hidden",
+            "new_df_dec_hidden",
+        ]

def forward(self, input_audio: Tensor, sample_rate: int) -> Tensor:
"""
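The minimal pipeline keeps its states unpacked: every entry of `self.states` stays a separate named input and output, which is the calling convention `main()` relies on when it runs `torch_df(*input_features)`. A small illustrative sketch of one streaming step, built only from attributes that appear in this commit:

```python
# Illustrative sketch: one streaming step of the minimal model with unpacked states.
import torch

from torch_df_streaming_minimal import TorchDFMinimalPipeline

pipeline = TorchDFMinimalPipeline(device="cpu")
model = pipeline.torch_streaming_model
states = pipeline.states  # one tensor per name in input_names[1:]

frame = torch.randn(pipeline.hop_size)
outputs = model(frame, *states)  # returned in output_names order
enhanced_frame, states = outputs[0], list(outputs[1:])
```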
