Skip to content

Commit

Permalink
Multiprocess pool refactored (#583)
Browse files Browse the repository at this point in the history
* [skip actions] [multiproc] 2024-07-14T09:25:19+03:00

* yapf style fix
  • Loading branch information
babenek authored Jul 16, 2024
1 parent 9acea8b commit 9800cb6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:
- name: Run CredSweeper tool
run: |
credsweeper --banner --jobs $(nproc) --path data --save-json report.${{ github.event.pull_request.head.sha }}.json | tee credsweeper.${{ github.event.pull_request.head.sha }}.log
credsweeper --banner --log info --jobs $(nproc) --path data --save-json report.${{ github.event.pull_request.head.sha }}.json | tee credsweeper.${{ github.event.pull_request.head.sha }}.log
- name: Run Benchmark
run: |
Expand Down
36 changes: 23 additions & 13 deletions credsweeper/app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import itertools
import logging
import multiprocessing
import signal
import sys
from pathlib import Path
from typing import Any, List, Optional, Union, Dict, Sequence, Tuple

Expand Down Expand Up @@ -253,10 +251,7 @@ def scan(self, content_providers: Sequence[Union[DiffContentProvider, TextConten

def __single_job_scan(self, content_providers: Sequence[Union[DiffContentProvider, TextContentProvider]]) -> None:
"""Performs scan in main thread"""
all_cred: List[Candidate] = []
for i in content_providers:
candidates = self.file_scan(i)
all_cred.extend(candidates)
all_cred = self.files_scan(content_providers)
if self.config.api_validation:
api_validation = ApplyValidation()
for cred in all_cred:
Expand All @@ -278,24 +273,39 @@ def __multi_jobs_scan(self, content_providers: Sequence[Union[DiffContentProvide
if "SILENCE" == self.__log_level:
logging.addLevelName(60, "SILENCE")
log_kwargs["level"] = self.__log_level
# providers_map: List[Sequence[Union[DiffContentProvider, TextContentProvider]]] = \
# [content_providers[x::self.pool_count] for x in range(self.pool_count)]
with multiprocessing.get_context("spawn").Pool(processes=self.pool_count,
initializer=self.pool_initializer,
initargs=(log_kwargs, )) as pool:
try:
# Get list credentials for each file
scan_results_per_file = pool.map(self.file_scan, content_providers)
# Join all sublist into a single list
scan_results = list(itertools.chain(*scan_results_per_file))
for cred in scan_results:
self.credential_manager.add_credential(cred)
for scan_results in pool.imap_unordered(self.files_scan, (content_providers[x::self.pool_count]
for x in range(self.pool_count))):
for cred in scan_results:
self.credential_manager.add_credential(cred)
if self.config.api_validation:
logger.info("Run API Validation")
api_validation = ApplyValidation()
api_validation.validate_credentials(pool, self.credential_manager)
except KeyboardInterrupt:
pool.terminate()
pool.join()
sys.exit()
raise
pool.close()
pool.join()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

def files_scan(
        self,  #
        content_providers: Sequence[Union[DiffContentProvider, TextContentProvider]]) -> List[Candidate]:
    """Scan one sequence of content providers and collect all candidates.

    Args:
        content_providers: providers to scan; each one is passed to ``self.file_scan``

    Return:
        one flat list with the candidates of every provider, kept in provider order

    """
    # flatten the per-provider results into a single list, scanning providers one by one
    all_cred: List[Candidate] = [
        candidate  #
        for provider in content_providers  #
        for candidate in self.file_scan(provider)
    ]
    logger.info(f"Completed: processed {len(content_providers)} providers with {len(all_cred)} candidates")
    return all_cred

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

Expand Down

0 comments on commit 9800cb6

Please sign in to comment.