diff --git a/bot.py b/bot.py index bd7636e..c9a3b1f 100644 --- a/bot.py +++ b/bot.py @@ -1,4 +1,5 @@ import os +import signal import sys import traceback from typing import Union, Optional @@ -6,12 +7,27 @@ import asyncio import uuid import json -from nio import (AsyncClient, AsyncClientConfig, InviteMemberEvent, JoinError, - KeyVerificationCancel, KeyVerificationEvent, DownloadError, - KeyVerificationKey, KeyVerificationMac, KeyVerificationStart, - LocalProtocolError, LoginResponse, MatrixRoom, MegolmEvent, - RoomMessageAudio, RoomEncryptedAudio, ToDeviceError, crypto, - EncryptionError) +from nio import ( + AsyncClient, + AsyncClientConfig, + InviteMemberEvent, + JoinError, + KeyVerificationCancel, + KeyVerificationEvent, + DownloadError, + KeyVerificationKey, + KeyVerificationMac, + KeyVerificationStart, + LocalProtocolError, + LoginResponse, + MatrixRoom, + MegolmEvent, + RoomMessageAudio, + RoomEncryptedAudio, + ToDeviceError, + crypto, + EncryptionError, +) from nio.store.database import SqliteStore from faster_whisper import WhisperModel @@ -40,12 +56,11 @@ def __init__( num_workers: int = 1, download_root: str = "models", ): - if (homeserver is None or user_id is None - or device_id is None): + if homeserver is None or user_id is None or device_id is None: logger.warning("homeserver && user_id && device_id is required") sys.exit(1) - if (password is None and access_token is None): + if password is None and access_token is None: logger.warning("password or access_toekn is required") sys.exit(1) @@ -87,26 +102,36 @@ def __init__( # initialize AsyncClient object self.store_path = os.getcwd() - self.config = AsyncClientConfig(store=SqliteStore, - store_name="db", - store_sync_tokens=True, - encryption_enabled=True, - ) - self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id, - config=self.config, store_path=self.store_path,) + self.config = AsyncClientConfig( + store=SqliteStore, + store_name="db", + store_sync_tokens=True, + encryption_enabled=True, + ) + self.client = AsyncClient( + homeserver=self.homeserver, + user=self.user_id, + device_id=self.device_id, + config=self.config, + store_path=self.store_path, + ) if self.access_token is not None: self.client.access_token = self.access_token # setup event callbacks self.client.add_event_callback( - self.message_callback, (RoomMessageAudio, RoomEncryptedAudio, )) - self.client.add_event_callback( - self.decryption_failure, (MegolmEvent, )) - self.client.add_event_callback( - self.invite_callback, (InviteMemberEvent, )) + self.message_callback, + ( + RoomMessageAudio, + RoomEncryptedAudio, + ), + ) + self.client.add_event_callback(self.decryption_failure, (MegolmEvent,)) + self.client.add_event_callback(self.invite_callback, (InviteMemberEvent,)) self.client.add_to_device_callback( - self.to_device_callback, (KeyVerificationEvent, )) + self.to_device_callback, (KeyVerificationEvent,) + ) # intialize whisper model self.model = WhisperModel( @@ -115,23 +140,19 @@ def __init__( compute_type=self.compute_type, cpu_threads=self.cpu_threads, num_workers=self.num_workers, - download_root=self.download_root,) - - def __del__(self): - try: - loop = asyncio.get_running_loop() - except RuntimeError as e: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._close()) + download_root=self.download_root, + ) - async def _close(self): + async def close(self, task: asyncio.Task = None) -> None: await self.client.close() - logger.info("Bot stopped!") + task.cancel() + logger.info("Bot closed!") + + # message_callback event - # message_callback event - async def message_callback(self, room: MatrixRoom, - event: Union[RoomMessageAudio, RoomEncryptedAudio]) -> None: + async def message_callback( + self, room: MatrixRoom, event: Union[RoomMessageAudio, RoomEncryptedAudio] + ) -> None: if self.room_id is None: room_id = room.room_id else: @@ -178,12 +199,8 @@ async def message_callback(self, room: MatrixRoom, await f.write( crypto.attachments.decrypt_attachment( media_data, - event.source["content"]["file"]["key"][ - "k" - ], - event.source["content"]["file"]["hashes"][ - "sha256" - ], + event.source["content"]["file"]["key"]["k"], + event.source["content"]["file"]["hashes"]["sha256"], event.source["content"]["file"]["iv"], ) ) @@ -216,8 +233,8 @@ async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None return logger.error( - f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" + - "Please make sure the bot current session is verified" + f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" + + "Please make sure the bot current session is verified" ) # invite_callback event @@ -233,7 +250,8 @@ async def invite_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> N if type(result) == JoinError: logger.error( f"Error joining room {room.room_id} (attempt %d): %s", - attempt, result.message, + attempt, + result.message, ) else: break @@ -255,11 +273,11 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: try: client = self.client logger.debug( - f"Device Event of type {type(event)} received in " - "to_device_cb().") + f"Device Event of type {type(event)} received in " "to_device_cb()." + ) if isinstance(event, KeyVerificationStart): # first step - """ first step: receive KeyVerificationStart + """first step: receive KeyVerificationStart KeyVerificationStart( source={'content': {'method': 'm.sas.v1', @@ -289,13 +307,14 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: """ if "emoji" not in event.short_authentication_string: - estr = ("Other device does not support emoji verification " - f"{event.short_authentication_string}. Aborting.") + estr = ( + "Other device does not support emoji verification " + f"{event.short_authentication_string}. Aborting." + ) print(estr) logger.info(estr) return - resp = await client.accept_key_verification( - event.transaction_id) + resp = await client.accept_key_verification(event.transaction_id) if isinstance(resp, ToDeviceError): estr = f"accept_key_verification() failed with {resp}" print(estr) @@ -311,7 +330,7 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: logger.info(estr) elif isinstance(event, KeyVerificationCancel): # anytime - """ at any time: receive KeyVerificationCancel + """at any time: receive KeyVerificationCancel KeyVerificationCancel(source={ 'content': {'code': 'm.mismatched_sas', 'reason': 'Mismatched authentication string', @@ -328,13 +347,15 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: # client.cancel_key_verification(tx_id, reject=False) # here. The SAS flow is already cancelled. # We only need to inform the user. - estr = (f"Verification has been cancelled by {event.sender} " - f"for reason \"{event.reason}\".") + estr = ( + f"Verification has been cancelled by {event.sender} " + f'for reason "{event.reason}".' + ) print(estr) logger.info(estr) elif isinstance(event, KeyVerificationKey): # second step - """ Second step is to receive KeyVerificationKey + """Second step is to receive KeyVerificationKey KeyVerificationKey( source={'content': { 'key': 'SomeCryptoKey', @@ -359,42 +380,44 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: # automatic match, so we use y yn = "y" if yn.lower() == "y": - estr = ("Match! The verification for this " - "device will be accepted.") + estr = ( + "Match! The verification for this " "device will be accepted." + ) print(estr) logger.info(estr) - resp = await client.confirm_short_auth_string( - event.transaction_id) + resp = await client.confirm_short_auth_string(event.transaction_id) if isinstance(resp, ToDeviceError): - estr = ("confirm_short_auth_string() " - f"failed with {resp}") + estr = "confirm_short_auth_string() " f"failed with {resp}" print(estr) logger.info(estr) elif yn.lower() == "n": # no, don't match, reject - estr = ("No match! Device will NOT be verified " - "by rejecting verification.") + estr = ( + "No match! Device will NOT be verified " + "by rejecting verification." + ) print(estr) logger.info(estr) resp = await client.cancel_key_verification( - event.transaction_id, reject=True) + event.transaction_id, reject=True + ) if isinstance(resp, ToDeviceError): - estr = (f"cancel_key_verification failed with {resp}") + estr = f"cancel_key_verification failed with {resp}" print(estr) logger.info(estr) else: # C or anything for cancel - estr = ("Cancelled by user! Verification will be " - "cancelled.") + estr = "Cancelled by user! Verification will be " "cancelled." print(estr) logger.info(estr) resp = await client.cancel_key_verification( - event.transaction_id, reject=False) + event.transaction_id, reject=False + ) if isinstance(resp, ToDeviceError): - estr = (f"cancel_key_verification failed with {resp}") + estr = f"cancel_key_verification failed with {resp}" print(estr) logger.info(estr) elif isinstance(event, KeyVerificationMac): # third step - """ Third step is to receive KeyVerificationMac + """Third step is to receive KeyVerificationMac KeyVerificationMac( source={'content': { 'mac': {'ed25519:DEVICEIDXY': 'SomeKey1', @@ -414,9 +437,11 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: todevice_msg = sas.get_mac() except LocalProtocolError as e: # e.g. it might have been cancelled by ourselves - estr = (f"Cancelled or protocol error: Reason: {e}.\n" - f"Verification with {event.sender} not concluded. " - "Try again?") + estr = ( + f"Cancelled or protocol error: Reason: {e}.\n" + f"Verification with {event.sender} not concluded. " + "Try again?" + ) print(estr) logger.info(estr) else: @@ -425,25 +450,31 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: estr = f"to_device failed with {resp}" print(estr) logger.info(estr) - estr = (f"sas.we_started_it = {sas.we_started_it}\n" - f"sas.sas_accepted = {sas.sas_accepted}\n" - f"sas.canceled = {sas.canceled}\n" - f"sas.timed_out = {sas.timed_out}\n" - f"sas.verified = {sas.verified}\n" - f"sas.verified_devices = {sas.verified_devices}\n") + estr = ( + f"sas.we_started_it = {sas.we_started_it}\n" + f"sas.sas_accepted = {sas.sas_accepted}\n" + f"sas.canceled = {sas.canceled}\n" + f"sas.timed_out = {sas.timed_out}\n" + f"sas.verified = {sas.verified}\n" + f"sas.verified_devices = {sas.verified_devices}\n" + ) print(estr) logger.info(estr) - estr = ("Emoji verification was successful!\n" - "Initiate another Emoji verification from " - "another device or room if desired. " - "Or if done verifying, hit Control-C to stop the " - "bot in order to restart it as a service or to " - "run it in the background.") + estr = ( + "Emoji verification was successful!\n" + "Initiate another Emoji verification from " + "another device or room if desired. " + "Or if done verifying, hit Control-C to stop the " + "bot in order to restart it as a service or to " + "run it in the background." + ) print(estr) logger.info(estr) else: - estr = (f"Received unexpected event type {type(event)}. " - f"Event is {event}. Event will be ignored.") + estr = ( + f"Received unexpected event type {type(event)}. " + f"Event is {event}. Event will be ignored." + ) print(estr) logger.info(estr) except BaseException: @@ -452,7 +483,6 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None: logger.info(estr) # bot login - async def login(self) -> None: if self.access_token is not None: logger.info("Login via access_token") @@ -480,14 +510,14 @@ async def download_mxc(self, mxc: str, filename: Optional[str] = None): # import keys async def import_keys(self): resp = await self.client.import_keys( - self.import_keys_path, - self.import_keys_password + self.import_keys_path, self.import_keys_password ) if isinstance(resp, EncryptionError): logger.error(f"import_keys failed with {resp}") else: logger.info( - f"import_keys success, please remove import_keys configuration!!!") + f"import_keys success, please remove import_keys configuration!!!" + ) # whisper function def transcribe(self, filename: str) -> str: @@ -507,30 +537,33 @@ async def main(): config = json.load(fp) bot = Bot( - homeserver=config.get('homeserver'), - user_id=config.get('user_id'), - password=config.get('password'), - device_id=config.get('device_id'), - room_id=config.get('room_id'), - access_token=config.get('access_token'), - import_keys_path=config.get('import_keys_path'), - import_keys_password=config.get('import_keys_password'), - model_size=config.get('model_size'), - device=config.get('device'), - compute_type=config.get('compute_type'), - cpu_threads=config.get('cpu_threads'), - num_workers=config.get('num_workers'), - download_root=config.get('download_root'), + homeserver=config.get("homeserver"), + user_id=config.get("user_id"), + password=config.get("password"), + device_id=config.get("device_id"), + room_id=config.get("room_id"), + access_token=config.get("access_token"), + import_keys_path=config.get("import_keys_path"), + import_keys_password=config.get("import_keys_password"), + model_size=config.get("model_size"), + device=config.get("device"), + compute_type=config.get("compute_type"), + cpu_threads=config.get("cpu_threads"), + num_workers=config.get("num_workers"), + download_root=config.get("download_root"), ) - if config.get('import_keys_path') and config.get('import_keys_password') is not None: + if ( + config.get("import_keys_path") + and config.get("import_keys_password") is not None + ): need_import_keys = True else: bot = Bot( - homeserver=os.environ.get('HOMESERVER'), - user_id=os.environ.get('USER_ID'), - password=os.environ.get('PASSWORD'), + homeserver=os.environ.get("HOMESERVER"), + user_id=os.environ.get("USER_ID"), + password=os.environ.get("PASSWORD"), device_id=os.environ.get("DEVICE_ID"), room_id=os.environ.get("ROOM_ID"), access_token=os.environ.get("ACCESS_TOKEN"), @@ -543,7 +576,10 @@ async def main(): num_workers=os.environ.get("NUM_WORKERS"), download_root=os.environ.get("DOWNLOAD_ROOT"), ) - if os.environ.get("IMPORT_KEYS_PATH") and os.environ.get("IMPORT_KEYS_PASSWORD") is not None: + if ( + os.environ.get("IMPORT_KEYS_PATH") + and os.environ.get("IMPORT_KEYS_PASSWORD") is not None + ): need_import_keys = True await bot.login() @@ -551,8 +587,21 @@ async def main(): logger.info("start import_keys process, this may take a while...") await bot.import_keys() - await bot.sync_forever() + sync_task = asyncio.create_task(bot.sync_forever()) + + # handle signal interrupt + loop = asyncio.get_running_loop() + for signame in ( + "SIGINT", + "SIGTERM", + ): + loop.add_signal_handler( + getattr(signal, signame), lambda: asyncio.create_task(bot.close(sync_task)) + ) + + await sync_task + -if __name__ == '__main__': +if __name__ == "__main__": logger.info("Bot started!") asyncio.run(main()) diff --git a/log.py b/log.py index c3f96d9..45ec996 100644 --- a/log.py +++ b/log.py @@ -9,17 +9,19 @@ def getlogger(): # create handlers warn_handler = logging.StreamHandler() info_handler = logging.StreamHandler() - error_handler = logging.FileHandler('bot.log', mode='a') + error_handler = logging.FileHandler("bot.log", mode="a") warn_handler.setLevel(logging.WARNING) error_handler.setLevel(logging.ERROR) info_handler.setLevel(logging.INFO) # create formatters warn_format = logging.Formatter( - '%(asctime)s - %(funcName)s - %(levelname)s - %(message)s') + "%(asctime)s - %(funcName)s - %(levelname)s - %(message)s" + ) error_format = logging.Formatter( - '%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s') - info_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s" + ) + info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") # set formatter warn_handler.setFormatter(warn_format) diff --git a/send_message.py b/send_message.py index b99263e..792df41 100644 --- a/send_message.py +++ b/send_message.py @@ -1,23 +1,42 @@ from nio import AsyncClient -async def send_room_message(client: AsyncClient, - room_id: str, - reply_message: str, - sender_id: str = '', - reply_to_event_id: str = '', - ) -> None: - NORMAL_BODY = content = {"msgtype": "m.text", "body": reply_message, } - if reply_to_event_id == '': - content = NORMAL_BODY + +async def send_room_message( + client: AsyncClient, + room_id: str, + reply_message: str, + sender_id: str = "", + reply_to_event_id: str = "", +) -> None: + NORMAL_BODY = content = { + "msgtype": "m.text", + "body": reply_message, + } + if reply_to_event_id == "": + content = NORMAL_BODY else: - body = r'> <' + sender_id + r'> sent an audio file.\n\n' + reply_message - format = r'org.matrix.custom.html' - formatted_body = r'
In reply to ' + sender_id \ - + r'
sent an audio file.
' + reply_message + body = r"> <" + sender_id + r"> sent an audio file.\n\n" + reply_message + format = r"org.matrix.custom.html" + formatted_body = ( + r'
In reply to ' + + sender_id + + r"
sent an audio file.
" + + reply_message + ) - content = {"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body, - "m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, } + content = { + "msgtype": "m.text", + "body": body, + "format": format, + "formatted_body": formatted_body, + "m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, + } await client.room_send( room_id, message_type="m.room.message",