Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
grazder committed Mar 14, 2024
1 parent dfa1227 commit 04d035c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 5 additions & 2 deletions torchDF/model_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,13 @@ def custom_identity(g: jit_utils.GraphContext, X):


def main(args):
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if args.minimal:
streaming_pipeline = TorchDFMinimalPipeline(device="cpu")
else:
streaming_pipeline = TorchDFPipeline(device="cpu")
streaming_pipeline = TorchDFPipeline(device="cpu", always_apply_all_stages=True)

frame_size = streaming_pipeline.hop_size
input_names = streaming_pipeline.input_names
Expand Down Expand Up @@ -212,7 +215,7 @@ def main(args):
input_shapes_dict = {x: y.shape for x, y in input_features_onnx.items()}

# Simplify not working for not minimal!
if args.simplify and args.minimal:
if args.simplify:
# raise NotImplementedError("Simplify not working for flatten states!")
onnx_simplify(args.output_path, input_features_onnx, input_shapes_dict)
logger.info(f"Model simplified! {args.output_path}")
Expand Down
4 changes: 1 addition & 3 deletions torchDF/torch_df_streaming_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
from df import init_df

from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor, nn
import torch.nn.functional as F

from loguru import logger

from df.config import Csv, DfParams, config
from df.modules import Conv2dNormAct, ConvTranspose2dNormAct

Expand Down

0 comments on commit 04d035c

Please sign in to comment.