diff --git a/bot.py b/bot.py
index bd7636e..c9a3b1f 100644
--- a/bot.py
+++ b/bot.py
@@ -1,4 +1,5 @@
import os
+import signal
import sys
import traceback
from typing import Union, Optional
@@ -6,12 +7,27 @@
import asyncio
import uuid
import json
-from nio import (AsyncClient, AsyncClientConfig, InviteMemberEvent, JoinError,
- KeyVerificationCancel, KeyVerificationEvent, DownloadError,
- KeyVerificationKey, KeyVerificationMac, KeyVerificationStart,
- LocalProtocolError, LoginResponse, MatrixRoom, MegolmEvent,
- RoomMessageAudio, RoomEncryptedAudio, ToDeviceError, crypto,
- EncryptionError)
+from nio import (
+ AsyncClient,
+ AsyncClientConfig,
+ InviteMemberEvent,
+ JoinError,
+ KeyVerificationCancel,
+ KeyVerificationEvent,
+ DownloadError,
+ KeyVerificationKey,
+ KeyVerificationMac,
+ KeyVerificationStart,
+ LocalProtocolError,
+ LoginResponse,
+ MatrixRoom,
+ MegolmEvent,
+ RoomMessageAudio,
+ RoomEncryptedAudio,
+ ToDeviceError,
+ crypto,
+ EncryptionError,
+)
from nio.store.database import SqliteStore
from faster_whisper import WhisperModel
@@ -40,12 +56,11 @@ def __init__(
num_workers: int = 1,
download_root: str = "models",
):
- if (homeserver is None or user_id is None
- or device_id is None):
+ if homeserver is None or user_id is None or device_id is None:
logger.warning("homeserver && user_id && device_id is required")
sys.exit(1)
- if (password is None and access_token is None):
+ if password is None and access_token is None:
logger.warning("password or access_toekn is required")
sys.exit(1)
@@ -87,26 +102,36 @@ def __init__(
# initialize AsyncClient object
self.store_path = os.getcwd()
- self.config = AsyncClientConfig(store=SqliteStore,
- store_name="db",
- store_sync_tokens=True,
- encryption_enabled=True,
- )
- self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
- config=self.config, store_path=self.store_path,)
+ self.config = AsyncClientConfig(
+ store=SqliteStore,
+ store_name="db",
+ store_sync_tokens=True,
+ encryption_enabled=True,
+ )
+ self.client = AsyncClient(
+ homeserver=self.homeserver,
+ user=self.user_id,
+ device_id=self.device_id,
+ config=self.config,
+ store_path=self.store_path,
+ )
if self.access_token is not None:
self.client.access_token = self.access_token
# setup event callbacks
self.client.add_event_callback(
- self.message_callback, (RoomMessageAudio, RoomEncryptedAudio, ))
- self.client.add_event_callback(
- self.decryption_failure, (MegolmEvent, ))
- self.client.add_event_callback(
- self.invite_callback, (InviteMemberEvent, ))
+ self.message_callback,
+ (
+ RoomMessageAudio,
+ RoomEncryptedAudio,
+ ),
+ )
+ self.client.add_event_callback(self.decryption_failure, (MegolmEvent,))
+ self.client.add_event_callback(self.invite_callback, (InviteMemberEvent,))
self.client.add_to_device_callback(
- self.to_device_callback, (KeyVerificationEvent, ))
+ self.to_device_callback, (KeyVerificationEvent,)
+ )
# intialize whisper model
self.model = WhisperModel(
@@ -115,23 +140,19 @@ def __init__(
compute_type=self.compute_type,
cpu_threads=self.cpu_threads,
num_workers=self.num_workers,
- download_root=self.download_root,)
-
- def __del__(self):
- try:
- loop = asyncio.get_running_loop()
- except RuntimeError as e:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- loop.run_until_complete(self._close())
+ download_root=self.download_root,
+ )
- async def _close(self):
+ async def close(self, task: asyncio.Task = None) -> None:
await self.client.close()
- logger.info("Bot stopped!")
+ task.cancel()
+ logger.info("Bot closed!")
+
+ # message_callback event
- # message_callback event
- async def message_callback(self, room: MatrixRoom,
- event: Union[RoomMessageAudio, RoomEncryptedAudio]) -> None:
+ async def message_callback(
+ self, room: MatrixRoom, event: Union[RoomMessageAudio, RoomEncryptedAudio]
+ ) -> None:
if self.room_id is None:
room_id = room.room_id
else:
@@ -178,12 +199,8 @@ async def message_callback(self, room: MatrixRoom,
await f.write(
crypto.attachments.decrypt_attachment(
media_data,
- event.source["content"]["file"]["key"][
- "k"
- ],
- event.source["content"]["file"]["hashes"][
- "sha256"
- ],
+ event.source["content"]["file"]["key"]["k"],
+ event.source["content"]["file"]["hashes"]["sha256"],
event.source["content"]["file"]["iv"],
)
)
@@ -216,8 +233,8 @@ async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None
return
logger.error(
- f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" +
- "Please make sure the bot current session is verified"
+ f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n"
+ + "Please make sure the bot current session is verified"
)
# invite_callback event
@@ -233,7 +250,8 @@ async def invite_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> N
if type(result) == JoinError:
logger.error(
f"Error joining room {room.room_id} (attempt %d): %s",
- attempt, result.message,
+ attempt,
+ result.message,
)
else:
break
@@ -255,11 +273,11 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
try:
client = self.client
logger.debug(
- f"Device Event of type {type(event)} received in "
- "to_device_cb().")
+ f"Device Event of type {type(event)} received in " "to_device_cb()."
+ )
if isinstance(event, KeyVerificationStart): # first step
- """ first step: receive KeyVerificationStart
+ """first step: receive KeyVerificationStart
KeyVerificationStart(
source={'content':
{'method': 'm.sas.v1',
@@ -289,13 +307,14 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
"""
if "emoji" not in event.short_authentication_string:
- estr = ("Other device does not support emoji verification "
- f"{event.short_authentication_string}. Aborting.")
+ estr = (
+ "Other device does not support emoji verification "
+ f"{event.short_authentication_string}. Aborting."
+ )
print(estr)
logger.info(estr)
return
- resp = await client.accept_key_verification(
- event.transaction_id)
+ resp = await client.accept_key_verification(event.transaction_id)
if isinstance(resp, ToDeviceError):
estr = f"accept_key_verification() failed with {resp}"
print(estr)
@@ -311,7 +330,7 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
logger.info(estr)
elif isinstance(event, KeyVerificationCancel): # anytime
- """ at any time: receive KeyVerificationCancel
+ """at any time: receive KeyVerificationCancel
KeyVerificationCancel(source={
'content': {'code': 'm.mismatched_sas',
'reason': 'Mismatched authentication string',
@@ -328,13 +347,15 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
# client.cancel_key_verification(tx_id, reject=False)
# here. The SAS flow is already cancelled.
# We only need to inform the user.
- estr = (f"Verification has been cancelled by {event.sender} "
- f"for reason \"{event.reason}\".")
+ estr = (
+ f"Verification has been cancelled by {event.sender} "
+ f'for reason "{event.reason}".'
+ )
print(estr)
logger.info(estr)
elif isinstance(event, KeyVerificationKey): # second step
- """ Second step is to receive KeyVerificationKey
+ """Second step is to receive KeyVerificationKey
KeyVerificationKey(
source={'content': {
'key': 'SomeCryptoKey',
@@ -359,42 +380,44 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
# automatic match, so we use y
yn = "y"
if yn.lower() == "y":
- estr = ("Match! The verification for this "
- "device will be accepted.")
+ estr = (
+ "Match! The verification for this " "device will be accepted."
+ )
print(estr)
logger.info(estr)
- resp = await client.confirm_short_auth_string(
- event.transaction_id)
+ resp = await client.confirm_short_auth_string(event.transaction_id)
if isinstance(resp, ToDeviceError):
- estr = ("confirm_short_auth_string() "
- f"failed with {resp}")
+ estr = "confirm_short_auth_string() " f"failed with {resp}"
print(estr)
logger.info(estr)
elif yn.lower() == "n": # no, don't match, reject
- estr = ("No match! Device will NOT be verified "
- "by rejecting verification.")
+ estr = (
+ "No match! Device will NOT be verified "
+ "by rejecting verification."
+ )
print(estr)
logger.info(estr)
resp = await client.cancel_key_verification(
- event.transaction_id, reject=True)
+ event.transaction_id, reject=True
+ )
if isinstance(resp, ToDeviceError):
- estr = (f"cancel_key_verification failed with {resp}")
+ estr = f"cancel_key_verification failed with {resp}"
print(estr)
logger.info(estr)
else: # C or anything for cancel
- estr = ("Cancelled by user! Verification will be "
- "cancelled.")
+ estr = "Cancelled by user! Verification will be " "cancelled."
print(estr)
logger.info(estr)
resp = await client.cancel_key_verification(
- event.transaction_id, reject=False)
+ event.transaction_id, reject=False
+ )
if isinstance(resp, ToDeviceError):
- estr = (f"cancel_key_verification failed with {resp}")
+ estr = f"cancel_key_verification failed with {resp}"
print(estr)
logger.info(estr)
elif isinstance(event, KeyVerificationMac): # third step
- """ Third step is to receive KeyVerificationMac
+ """Third step is to receive KeyVerificationMac
KeyVerificationMac(
source={'content': {
'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
@@ -414,9 +437,11 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
todevice_msg = sas.get_mac()
except LocalProtocolError as e:
# e.g. it might have been cancelled by ourselves
- estr = (f"Cancelled or protocol error: Reason: {e}.\n"
- f"Verification with {event.sender} not concluded. "
- "Try again?")
+ estr = (
+ f"Cancelled or protocol error: Reason: {e}.\n"
+ f"Verification with {event.sender} not concluded. "
+ "Try again?"
+ )
print(estr)
logger.info(estr)
else:
@@ -425,25 +450,31 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
estr = f"to_device failed with {resp}"
print(estr)
logger.info(estr)
- estr = (f"sas.we_started_it = {sas.we_started_it}\n"
- f"sas.sas_accepted = {sas.sas_accepted}\n"
- f"sas.canceled = {sas.canceled}\n"
- f"sas.timed_out = {sas.timed_out}\n"
- f"sas.verified = {sas.verified}\n"
- f"sas.verified_devices = {sas.verified_devices}\n")
+ estr = (
+ f"sas.we_started_it = {sas.we_started_it}\n"
+ f"sas.sas_accepted = {sas.sas_accepted}\n"
+ f"sas.canceled = {sas.canceled}\n"
+ f"sas.timed_out = {sas.timed_out}\n"
+ f"sas.verified = {sas.verified}\n"
+ f"sas.verified_devices = {sas.verified_devices}\n"
+ )
print(estr)
logger.info(estr)
- estr = ("Emoji verification was successful!\n"
- "Initiate another Emoji verification from "
- "another device or room if desired. "
- "Or if done verifying, hit Control-C to stop the "
- "bot in order to restart it as a service or to "
- "run it in the background.")
+ estr = (
+ "Emoji verification was successful!\n"
+ "Initiate another Emoji verification from "
+ "another device or room if desired. "
+ "Or if done verifying, hit Control-C to stop the "
+ "bot in order to restart it as a service or to "
+ "run it in the background."
+ )
print(estr)
logger.info(estr)
else:
- estr = (f"Received unexpected event type {type(event)}. "
- f"Event is {event}. Event will be ignored.")
+ estr = (
+ f"Received unexpected event type {type(event)}. "
+ f"Event is {event}. Event will be ignored."
+ )
print(estr)
logger.info(estr)
except BaseException:
@@ -452,7 +483,6 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
logger.info(estr)
# bot login
-
async def login(self) -> None:
if self.access_token is not None:
logger.info("Login via access_token")
@@ -480,14 +510,14 @@ async def download_mxc(self, mxc: str, filename: Optional[str] = None):
# import keys
async def import_keys(self):
resp = await self.client.import_keys(
- self.import_keys_path,
- self.import_keys_password
+ self.import_keys_path, self.import_keys_password
)
if isinstance(resp, EncryptionError):
logger.error(f"import_keys failed with {resp}")
else:
logger.info(
- f"import_keys success, please remove import_keys configuration!!!")
+ f"import_keys success, please remove import_keys configuration!!!"
+ )
# whisper function
def transcribe(self, filename: str) -> str:
@@ -507,30 +537,33 @@ async def main():
config = json.load(fp)
bot = Bot(
- homeserver=config.get('homeserver'),
- user_id=config.get('user_id'),
- password=config.get('password'),
- device_id=config.get('device_id'),
- room_id=config.get('room_id'),
- access_token=config.get('access_token'),
- import_keys_path=config.get('import_keys_path'),
- import_keys_password=config.get('import_keys_password'),
- model_size=config.get('model_size'),
- device=config.get('device'),
- compute_type=config.get('compute_type'),
- cpu_threads=config.get('cpu_threads'),
- num_workers=config.get('num_workers'),
- download_root=config.get('download_root'),
+ homeserver=config.get("homeserver"),
+ user_id=config.get("user_id"),
+ password=config.get("password"),
+ device_id=config.get("device_id"),
+ room_id=config.get("room_id"),
+ access_token=config.get("access_token"),
+ import_keys_path=config.get("import_keys_path"),
+ import_keys_password=config.get("import_keys_password"),
+ model_size=config.get("model_size"),
+ device=config.get("device"),
+ compute_type=config.get("compute_type"),
+ cpu_threads=config.get("cpu_threads"),
+ num_workers=config.get("num_workers"),
+ download_root=config.get("download_root"),
)
- if config.get('import_keys_path') and config.get('import_keys_password') is not None:
+ if (
+ config.get("import_keys_path")
+ and config.get("import_keys_password") is not None
+ ):
need_import_keys = True
else:
bot = Bot(
- homeserver=os.environ.get('HOMESERVER'),
- user_id=os.environ.get('USER_ID'),
- password=os.environ.get('PASSWORD'),
+ homeserver=os.environ.get("HOMESERVER"),
+ user_id=os.environ.get("USER_ID"),
+ password=os.environ.get("PASSWORD"),
device_id=os.environ.get("DEVICE_ID"),
room_id=os.environ.get("ROOM_ID"),
access_token=os.environ.get("ACCESS_TOKEN"),
@@ -543,7 +576,10 @@ async def main():
num_workers=os.environ.get("NUM_WORKERS"),
download_root=os.environ.get("DOWNLOAD_ROOT"),
)
- if os.environ.get("IMPORT_KEYS_PATH") and os.environ.get("IMPORT_KEYS_PASSWORD") is not None:
+ if (
+ os.environ.get("IMPORT_KEYS_PATH")
+ and os.environ.get("IMPORT_KEYS_PASSWORD") is not None
+ ):
need_import_keys = True
await bot.login()
@@ -551,8 +587,21 @@ async def main():
logger.info("start import_keys process, this may take a while...")
await bot.import_keys()
- await bot.sync_forever()
+ sync_task = asyncio.create_task(bot.sync_forever())
+
+ # handle signal interrupt
+ loop = asyncio.get_running_loop()
+ for signame in (
+ "SIGINT",
+ "SIGTERM",
+ ):
+ loop.add_signal_handler(
+ getattr(signal, signame), lambda: asyncio.create_task(bot.close(sync_task))
+ )
+
+ await sync_task
+
-if __name__ == '__main__':
+if __name__ == "__main__":
logger.info("Bot started!")
asyncio.run(main())
diff --git a/log.py b/log.py
index c3f96d9..45ec996 100644
--- a/log.py
+++ b/log.py
@@ -9,17 +9,19 @@ def getlogger():
# create handlers
warn_handler = logging.StreamHandler()
info_handler = logging.StreamHandler()
- error_handler = logging.FileHandler('bot.log', mode='a')
+ error_handler = logging.FileHandler("bot.log", mode="a")
warn_handler.setLevel(logging.WARNING)
error_handler.setLevel(logging.ERROR)
info_handler.setLevel(logging.INFO)
# create formatters
warn_format = logging.Formatter(
- '%(asctime)s - %(funcName)s - %(levelname)s - %(message)s')
+ "%(asctime)s - %(funcName)s - %(levelname)s - %(message)s"
+ )
error_format = logging.Formatter(
- '%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s')
- info_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ "%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s"
+ )
+ info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# set formatter
warn_handler.setFormatter(warn_format)
diff --git a/send_message.py b/send_message.py
index b99263e..792df41 100644
--- a/send_message.py
+++ b/send_message.py
@@ -1,23 +1,42 @@
from nio import AsyncClient
-async def send_room_message(client: AsyncClient,
- room_id: str,
- reply_message: str,
- sender_id: str = '',
- reply_to_event_id: str = '',
- ) -> None:
- NORMAL_BODY = content = {"msgtype": "m.text", "body": reply_message, }
- if reply_to_event_id == '':
- content = NORMAL_BODY
+
+async def send_room_message(
+ client: AsyncClient,
+ room_id: str,
+ reply_message: str,
+ sender_id: str = "",
+ reply_to_event_id: str = "",
+) -> None:
+ NORMAL_BODY = content = {
+ "msgtype": "m.text",
+ "body": reply_message,
+ }
+ if reply_to_event_id == "":
+ content = NORMAL_BODY
else:
- body = r'> <' + sender_id + r'> sent an audio file.\n\n' + reply_message
- format = r'org.matrix.custom.html'
- formatted_body = r'In reply to ' + sender_id \
- + r'
sent an audio file.
' + reply_message
+ body = r"> <" + sender_id + r"> sent an audio file.\n\n" + reply_message
+ format = r"org.matrix.custom.html"
+ formatted_body = (
+ r'In reply to '
+ + sender_id
+ + r"
sent an audio file.
"
+ + reply_message
+ )
- content = {"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
- "m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
+ content = {
+ "msgtype": "m.text",
+ "body": body,
+ "format": format,
+ "formatted_body": formatted_body,
+ "m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
+ }
await client.room_send(
room_id,
message_type="m.room.message",