diff --git a/README.md b/README.md index b76f26b5..ac347592 100644 --- a/README.md +++ b/README.md @@ -18,24 +18,33 @@ ## 功能 & 特点 - 基于异步 SQLAlchemy / MySQL 的数据存储 -- 基于群组的通知权限、命令权限以及权限等级系统 -- 基于插件节点的权限管理系统 +- 权限控制及管理系统 + - 针对不同群组可选启用通知权限、命令权限、权限等级控制 + - 针对不同好友可选启用 Bot 功能 + - 针对不同群组、好友独立配置插件权限节点 +- 支持多协议端连接, 各协议端权限及订阅配置相互独立 - 命令冷却系统 - HTTP 代理功能 - 自动处理加好友和被邀请进群 - 插件帮助功能 (支持群聊 / 私聊) - Bot对群组公告功能 (仅支持对群组) +- 定时消息功能 (仅支持对群组) +- 反闪照 (仅支持群聊) +- 反撤回 (仅支持群聊) - B站动态订阅 (建议配置B站cookies) (支持群聊 / 私聊) - B站直播间监控 (建议配置B站cookies) (支持群聊 / 私聊) +- 签到 (仅支持群聊) - 求签 (仅支持群聊) - 抽卡 (仅支持群聊) - 能不能好好说话 (lab.magiconch.com API) (支持群聊 / 私聊) - Pixiv助手 (需要 HTTP 代理, 除非部署在外网) (需要 go-cqhttp v0.9.40 及以上版本) (仅支持群聊) +- Pixiv订阅 (需要 HTTP 代理, 除非部署在外网) (仅支持群聊) - Pixivision订阅 (需要 HTTP 代理, 除非部署在外网) (仅支持群聊) - 复读姬 (仅支持群聊) - roll点抽奖 (仅支持群聊) +- ShindanMaker占卜 (shindanmaker.com / 建议使用 HTTP 代理) (仅支持群聊) - 搜番剧 (trace.moe API / 建议使用 HTTP 代理) (支持群聊 / 私聊) -- 搜二次元图 (Saucenao API 和 ascii2d / 建议使用 HTTP 代理) (支持群聊 / 私聊) +- 搜二次元图 (Saucenao API, iqbb 和 ascii2d / 建议使用 HTTP 代理) (支持群聊 / 私聊) - 来点萌图 / 来点涩图 (需要 HTTP 代理, 除非部署在外网 / 图片数据库需要自己导入) (支持群聊 / 私聊) - 表情包制作器 (支持群聊 / 私聊) - 猫按钮 (测试) (仅支持群聊) @@ -45,29 +54,21 @@ ## 如何使用 -0. 首先得有个MySQL数据库 - -1. 安装依赖: `pip install -r requirements.txt` - -2. 配置.env中数据库相关配置(必需), 其他配置可选 - -3. 运行`python bot.py` - -4. 在群组中使用 `/Omega` `/OmegaAuth` 等命令配置群组权限 +请参考本仓库 [Wiki](https://github.com/Ailitonia/omega-miya/wiki) ## 关于图片数据 如果你不想自己收集图片数据, 可以将 -[这组图片数据集](https://github.com/Ailitonia/omega-miya/raw/main/archive_data/db_pixiv.7z) -导入数据库 +[这个图片数据库](https://github.com/Ailitonia/omega-miya/raw/main/archive_data/db_pixiv.7z) +导入, 基本都是按我自己口味收集的图片 -这个图片集大概有5万条左右, 基本都是按我自己口味收集的图片 +Update 2021.8.10: 最新发布图片数据库共 9w7 条图片数据 (包含已失效或画师已删除作品) 解压后直接把 `omega_pixiv_illusts.sql` 导入对应的 pixiv_illusts 表就好了 -MD5: `8BF375B9687C397AE2040C8F9E28F68E` +MD5: `7AC9A77545E37F1B99F8D1948D0A9A78` -SHA1: `7CFF3593A85979B5D966773F3857577CFCC2FFBD` +SHA1: `1F129A18905D1590379AC761E2EAC69DAC2D42DA` 数据集来源是我的 [这个频道](https://t.me/amoeloli) diff --git a/archive_data/db_pixiv.7z b/archive_data/db_pixiv.7z index ccccb755..a2dff548 100644 Binary files a/archive_data/db_pixiv.7z and b/archive_data/db_pixiv.7z differ diff --git a/archive_data/omega_recommend_pixiv_illust_filtered.7z b/archive_data/omega_recommend_pixiv_illust_filtered.7z deleted file mode 100644 index 7cd30dee..00000000 Binary files a/archive_data/omega_recommend_pixiv_illust_filtered.7z and /dev/null differ diff --git a/bot.py b/bot.py index 663b85d6..652033bb 100644 --- a/bot.py +++ b/bot.py @@ -29,6 +29,11 @@ logger.add(log_info_path, rotation="00:00", diagnose=False, level="INFO", format=default_format, encoding='utf-8') logger.add(log_error_path, rotation="00:00", diagnose=False, level="ERROR", format=default_format, encoding='utf-8') +# Add extra debug log file +# log_debug_name = f"{datetime.today().strftime('%Y%m%d-%H%M%S')}-DEBUG.log" +# log_debug_path = os.path.join(bot_log_path, log_debug_name) +# logger.add(log_debug_path, rotation="00:00", diagnose=False, level="DEBUG", format=default_format, encoding='utf-8') + # You can pass some keyword args config to init function nonebot.init() @@ -37,7 +42,6 @@ config.root_path_ = bot_root_path config.tmp_path_ = bot_tmp_path - # 注册 cqhttp adapter driver = nonebot.get_driver() driver.register_adapter("cqhttp", CQHTTPBot) diff --git a/omega_miya/plugins/Omega_anti_flash/__init__.py b/omega_miya/plugins/Omega_anti_flash/__init__.py new file mode 100644 index 00000000..e4aea253 --- /dev/null +++ b/omega_miya/plugins/Omega_anti_flash/__init__.py @@ -0,0 +1,133 @@ +from nonebot import MatcherGroup, export, logger +from nonebot.permission import SUPERUSER +from nonebot.typing import T_State +from nonebot.adapters.cqhttp import Message +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP, GROUP_ADMIN, GROUP_OWNER +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBAuth, Result +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, OmegaRules + + +# Custom plugin usage text +__plugin_raw_name__ = __name__.split('.')[-1] +__plugin_name__ = 'AntiFlash' +__plugin_usage__ = r'''【AntiFlash 反闪照】 +检测闪照并提取原图 + +**Permission** +Group only with +AuthNode + +**AuthNode** +basic + +**Usage** +**GroupAdmin and SuperUser Only** +/AntiFlash ''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +# 注册事件响应器 +AntiFlash = MatcherGroup(type='message', permission=GROUP, priority=100, block=False) + +anti_flash_admin = AntiFlash.on_command( + 'AntiFlash', + aliases={'antiflash', '反闪照'}, + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='anti_flash', + command=True, + level=10), + permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, + priority=10, + block=True) + + +# 修改默认参数处理 +@anti_flash_admin.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await anti_flash_admin.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await anti_flash_admin.finish('操作已取消') + + +@anti_flash_admin.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['sub_command'] = args[0] + else: + await anti_flash_admin.finish('参数错误QAQ') + + +@anti_flash_admin.got('sub_command', prompt='执行操作?\n【ON/OFF】') +async def handle_sub_command_args(bot: Bot, event: GroupMessageEvent, state: T_State): + sub_command = state['sub_command'] + if sub_command not in ['on', 'off']: + await anti_flash_admin.reject('没有这个选项哦, 请在【ON/OFF】中选择并重新发送, 取消命令请发送【取消】:') + + if sub_command == 'on': + _res = await anti_flash_on(bot=bot, event=event, state=state) + elif sub_command == 'off': + _res = await anti_flash_off(bot=bot, event=event, state=state) + else: + _res = Result.IntResult(error=True, info='Unknown error, except sub_command', result=-1) + + if _res.success(): + logger.info(f"设置 AntiFlash 状态为 {sub_command} 成功, group_id: {event.group_id}, {_res.info}") + await anti_flash_admin.finish(f'已设置 AntiFlash 状态为 {sub_command}!') + else: + logger.error(f"设置 AntiFlash 状态为 {sub_command} 失败, group_id: {event.group_id}, {_res.info}") + await anti_flash_admin.finish(f'设置 AntiFlash 状态失败了QAQ, 请稍后再试~') + + +async def anti_flash_on(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + return Result.IntResult(error=False, info='Group not exist', result=-1) + + auth_node = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic') + result = await auth_node.set(allow_tag=1, deny_tag=0, auth_info='启用反闪照') + return result + + +async def anti_flash_off(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + return Result.IntResult(error=False, info='Group not exist', result=-1) + + auth_node = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic') + result = await auth_node.set(allow_tag=0, deny_tag=1, auth_info='禁用反闪照') + return result + + +anti_flash_handler = AntiFlash.on_message(rule=OmegaRules.has_auth_node(__plugin_raw_name__, 'basic')) + + +@anti_flash_handler.handle() +async def check_flash_img(bot: Bot, event: GroupMessageEvent, state: T_State): + for msg_seg in event.message: + if msg_seg.type == 'image': + if msg_seg.data.get('type') == 'flash': + msg = Message('AntiFlash 已检测到闪照:\n').append(str(msg_seg).replace(',type=flash', '')) + logger.info(f'AntiFlash 已处理闪照, message_id: {event.message_id}') + await anti_flash_handler.finish(msg) diff --git a/omega_miya/plugins/Omega_anti_recall/__init__.py b/omega_miya/plugins/Omega_anti_recall/__init__.py new file mode 100644 index 00000000..2f516653 --- /dev/null +++ b/omega_miya/plugins/Omega_anti_recall/__init__.py @@ -0,0 +1,143 @@ +from nonebot import on_command, on_notice, export, logger +from nonebot.permission import SUPERUSER +from nonebot.typing import T_State +from nonebot.adapters.cqhttp import Message +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import GroupMessageEvent, GroupRecallNoticeEvent +from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBAuth, DBHistory, Result +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, OmegaRules + + +# Custom plugin usage text +__plugin_raw_name__ = __name__.split('.')[-1] +__plugin_name__ = 'AntiRecall' +__plugin_usage__ = r'''【AntiRecall 反撤回】 +检测消息撤回并提取原消息 + +**Permission** +Group only with +AuthNode + +**AuthNode** +basic + +**Usage** +**GroupAdmin and SuperUser Only** +/AntiRecall ''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +# 注册事件响应器 +anti_recall_admin = on_command( + 'AntiRecall', + aliases={'antirecall', '反撤回'}, + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='anti_recall', + command=True, + level=10), + permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, + priority=10, + block=True) + + +# 修改默认参数处理 +@anti_recall_admin.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await anti_recall_admin.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await anti_recall_admin.finish('操作已取消') + + +@anti_recall_admin.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['sub_command'] = args[0] + else: + await anti_recall_admin.finish('参数错误QAQ') + + +@anti_recall_admin.got('sub_command', prompt='执行操作?\n【ON/OFF】') +async def handle_sub_command_args(bot: Bot, event: GroupMessageEvent, state: T_State): + sub_command = state['sub_command'] + if sub_command not in ['on', 'off']: + await anti_recall_admin.reject('没有这个选项哦, 请在【ON/OFF】中选择并重新发送, 取消命令请发送【取消】:') + + if sub_command == 'on': + _res = await anti_recall_on(bot=bot, event=event, state=state) + elif sub_command == 'off': + _res = await anti_recall_off(bot=bot, event=event, state=state) + else: + _res = Result.IntResult(error=True, info='Unknown error, except sub_command', result=-1) + + if _res.success(): + logger.info(f"设置 AntiRecall 状态为 {sub_command} 成功, group_id: {event.group_id}, {_res.info}") + await anti_recall_admin.finish(f'已设置 AntiRecall 状态为 {sub_command}!') + else: + logger.error(f"设置 AntiRecall 状态为 {sub_command} 失败, group_id: {event.group_id}, {_res.info}") + await anti_recall_admin.finish(f'设置 AntiRecall 状态失败了QAQ, 请稍后再试~') + + +async def anti_recall_on(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + return Result.IntResult(error=False, info='Group not exist', result=-1) + + auth_node = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic') + result = await auth_node.set(allow_tag=1, deny_tag=0, auth_info='启用反撤回') + return result + + +async def anti_recall_off(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + return Result.IntResult(error=False, info='Group not exist', result=-1) + + auth_node = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic') + result = await auth_node.set(allow_tag=0, deny_tag=1, auth_info='禁用反撤回') + return result + + +anti_recall_handler = on_notice(rule=OmegaRules.has_auth_node(__plugin_raw_name__, 'basic'), priority=100, block=False) + + +@anti_recall_handler.handle() +async def check_recall_notice(bot: Bot, event: GroupRecallNoticeEvent, state: T_State): + self_id = event.self_id + group_id = event.group_id + user_id = event.user_id + message_id = event.message_id + history_result = await DBHistory.search_unique_msg( + self_id=self_id, post_type='message', detail_type='group', sub_type='normal', + event_id=message_id, group_id=group_id, user_id=user_id) + if history_result.error: + logger.error(f'AntiRecall 查询历史消息失败, message_id: {message_id}, error: {history_result.info}') + return + else: + history = history_result.result + user_name = history.user_name + time = history.created_at + msg = history.msg_data + send_msg = Message(f"AntiRecall 已检测到撤回消息:\n{time}@{user_name}:\n").append(msg) + logger.info(f'AntiRecall 已处理撤回消息, message_id: {message_id}') + await anti_recall_handler.finish(send_msg) diff --git a/omega_miya/plugins/Omega_auth_manage/__init__.py b/omega_miya/plugins/Omega_auth_manager/__init__.py similarity index 83% rename from omega_miya/plugins/Omega_auth_manage/__init__.py rename to omega_miya/plugins/Omega_auth_manager/__init__.py index 59a954ef..562e7b0f 100644 --- a/omega_miya/plugins/Omega_auth_manage/__init__.py +++ b/omega_miya/plugins/Omega_auth_manager/__init__.py @@ -4,8 +4,8 @@ from nonebot.permission import SUPERUSER from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot -from nonebot.adapters.cqhttp.event import MessageEvent -from omega_miya.utils.Omega_Base import DBUser, DBGroup, DBAuth +from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from omega_miya.utils.Omega_Base import DBBot, DBFriend, DBBotGroup, DBAuth from omega_miya.utils.Omega_plugin_utils import init_export @@ -49,31 +49,27 @@ async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): await omegaauth.finish('参数错误QAQ') -@omegaauth.got('sub_command', prompt='执行操作?\n【allow/deny/clear/list】') +# 处理显示权限节点列表事件 +@omegaauth.got('sub_command', prompt='执行操作?\n【allow/deny/clear/custom_*/list】') async def handle_sub_command(bot: Bot, event: MessageEvent, state: T_State): sub_command = state["sub_command"] - if sub_command not in ['allow', 'deny', 'clear', 'list']: + if sub_command not in ['allow', 'deny', 'clear', 'list', 'custom_allow', 'custom_deny', 'custom_clear']: await omegaauth.finish('参数错误QAQ') - -# 处理显示权限节点列表事件 -@omegaauth.got('sub_command', prompt='list:') -async def handle_list_node(bot: Bot, event: MessageEvent, state: T_State): - sub_command = state["sub_command"] + self_bot = DBBot(self_qq=int(bot.self_id)) if sub_command == 'list': - detail_type = event.dict().get(f'{event.get_type()}_type') - if detail_type == 'group': - group_id = event.dict().get('group_id') - _res = await DBAuth.list(auth_type='group', auth_id=group_id) + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + _res = await DBAuth.list(auth_type='group', auth_id=group_id, self_bot=self_bot) if _res.success(): node_text = '\n'.join('/'.join(map(str, n)) for n in _res.result) msg = f'当前群组权限列表为:\n\n{node_text}' await omegaauth.finish(msg) else: await omegaauth.finish('发生了意外的错误QAQ, 请稍后再试') - elif detail_type == 'private': - user_id = event.dict().get('user_id') - _res = await DBAuth.list(auth_type='user', auth_id=user_id) + elif isinstance(event, PrivateMessageEvent): + user_id = event.user_id + _res = await DBAuth.list(auth_type='user', auth_id=user_id, self_bot=self_bot) if _res.success(): node_text = '\n'.join('/'.join(map(str, n)) for n in _res.result) msg = f'当前用户权限列表为:\n\n{node_text}' @@ -82,6 +78,13 @@ async def handle_list_node(bot: Bot, event: MessageEvent, state: T_State): await omegaauth.finish('发生了意外的错误QAQ, 请稍后再试') else: await omegaauth.finish('非授权会话, 操作中止') + elif sub_command in ['allow', 'deny', 'clear']: + if isinstance(event, GroupMessageEvent): + state["auth_type"] = 'group' + state["auth_id"] = str(event.group_id) + elif isinstance(event, PrivateMessageEvent): + state["auth_type"] = 'user' + state["auth_id"] = str(event.user_id) @omegaauth.got('auth_type', prompt='授权类型?\n【user/group】') @@ -95,11 +98,12 @@ async def handle_auth_type(bot: Bot, event: MessageEvent, state: T_State): async def handle_auth_id(bot: Bot, event: MessageEvent, state: T_State): auth_type = state["auth_type"] auth_id = state["auth_id"] + self_bot = DBBot(self_qq=int(bot.self_id)) if not re.match(r'^\d+$', auth_id): await omegaauth.finish('参数错误QAQ, qq或群号应为纯数字') if auth_type == 'user': - user = DBUser(user_id=auth_id) + user = DBFriend(user_id=auth_id, self_bot=self_bot) user_name_res = await user.nickname() if user_name_res.success(): await omegaauth.send(f'即将对用户: 【{user_name_res.result}】执行操作') @@ -107,7 +111,7 @@ async def handle_auth_id(bot: Bot, event: MessageEvent, state: T_State): logger.error(f'为 {auth_type}/{auth_id} 配置权限节点失败, 数据库中不存在该用户') await omegaauth.finish('数据库中不存在该用户QAQ') elif auth_type == 'group': - group = DBGroup(group_id=auth_id) + group = DBBotGroup(group_id=auth_id, self_bot=self_bot) group_name_res = await group.name() if group_name_res.success(): await omegaauth.send(f'即将对群组: 【{group_name_res.result}】执行操作') @@ -152,10 +156,11 @@ async def handle_auth_node(bot: Bot, event: MessageEvent, state: T_State): r_auth_node = '.'.join([plugin, auth_node]) auth_id = state["auth_id"] - sub_command = state["sub_command"] + sub_command = re.sub(r'^custom_', '', str(state["sub_command"])) auth_type = state["auth_type"] + self_bot = DBBot(self_qq=int(bot.self_id)) - auth = DBAuth(auth_id=auth_id, auth_type=auth_type, auth_node=r_auth_node) + auth = DBAuth(auth_id=auth_id, auth_type=auth_type, auth_node=r_auth_node, self_bot=self_bot) if sub_command == 'allow': res = await auth.set(allow_tag=1, deny_tag=0) @@ -164,7 +169,7 @@ async def handle_auth_node(bot: Bot, event: MessageEvent, state: T_State): elif sub_command == 'clear': res = await auth.delete() else: - logger.error(f'handle_auth_node 执行时 sub_command 变量检验错误') + logger.error(f'handle_auth_node 执行时 sub_command 命令检验错误') return if res.success(): diff --git a/omega_miya/plugins/Omega_auto_manager/__init__.py b/omega_miya/plugins/Omega_auto_manager/__init__.py new file mode 100644 index 00000000..80a07114 --- /dev/null +++ b/omega_miya/plugins/Omega_auto_manager/__init__.py @@ -0,0 +1,37 @@ +""" +@Author : Ailitonia +@Date : 2021/06/09 19:10 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : Omega 自动化综合/群管/好友管理插件 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from nonebot import export +from omega_miya.utils.Omega_plugin_utils import init_export +from .group_welcome_message import * +from .invite_request import * + + +# Custom plugin usage text +__plugin_name__ = 'OmegaAutoManager' +__plugin_usage__ = r'''【Omega 自动化综合/群管/好友管理插件】 +Omega机器人自动化综合/群管/好友管理 + +以下命令均需要@机器人 +**Usage** +**GroupAdmin and SuperUser Only** +/设置欢迎消息 +/清空欢迎消息 +''' + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__) + + +__all__ = [ + 'WelcomeMsg', + 'group_increase', + 'add_and_invite_request' +] diff --git a/omega_miya/plugins/Omega_auto_manager/group_welcome_message.py b/omega_miya/plugins/Omega_auto_manager/group_welcome_message.py new file mode 100644 index 00000000..4049a2cd --- /dev/null +++ b/omega_miya/plugins/Omega_auto_manager/group_welcome_message.py @@ -0,0 +1,112 @@ +""" +@Author : Ailitonia +@Date : 2021/06/11 23:42 +@FileName : group_welcome_message.py +@Project : nonebot2_miya +@Description : 群自定义欢迎信息 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from nonebot import logger +from nonebot.plugin import on_notice, CommandGroup +from nonebot.typing import T_State +from nonebot.rule import to_me +from nonebot.permission import SUPERUSER +from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.message import Message, MessageSegment +from nonebot.adapters.cqhttp.event import GroupMessageEvent, GroupIncreaseNoticeEvent +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup +from omega_miya.utils.Omega_plugin_utils import OmegaRules + + +SETTING_NAME: str = 'group_welcome_message' +DEFAULT_WELCOME_MSG: str = '欢迎新朋友~\n进群请先看群公告~\n一起愉快地聊天吧!' + + +# 注册事件响应器 +WelcomeMsg = CommandGroup( + 'WelcomeMsg', + rule=to_me(), + permission=SUPERUSER | GROUP_ADMIN | GROUP_OWNER, + priority=10, + block=True +) + +welcome_msg_set = WelcomeMsg.command('set', aliases={'设置欢迎消息'}) +welcome_msg_clear = WelcomeMsg.command('clear', aliases={'清空欢迎消息'}) + + +# 修改默认参数处理 +@welcome_msg_set.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_message()).strip() + if not args: + await welcome_msg_set.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args + if state[state["_current_key"]] == '取消': + await welcome_msg_set.finish('操作已取消') + + +@welcome_msg_set.got('welcome_msg', prompt='请发送你要设置的欢迎消息:') +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + welcome_msg = state['welcome_msg'] + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + msg_set_result = await group.setting_set(setting_name=SETTING_NAME, main_config='Custom', + extra_config=welcome_msg, setting_info='自定义群组欢迎信息') + if msg_set_result.success(): + logger.info(f'已为群组: {group_id} 设置自定义欢迎信息: {welcome_msg}') + await welcome_msg_set.finish(f'已为本群组设定了自定义欢迎信息!') + else: + logger.error(f'为群组: {group_id} 设置自定义欢迎信息失败, error info: {msg_set_result.info}') + await welcome_msg_set.finish(f'为本群组设定自定义欢迎信息失败了QAQ, 请稍后再试或联系管理员处理') + + +@welcome_msg_clear.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + msg_set_result = await group.setting_del(setting_name=SETTING_NAME) + if msg_set_result.success(): + logger.info(f'已为群组: {group_id} 清除自定义欢迎信息') + await welcome_msg_clear.finish(f'已清除了本群组设定的自定义欢迎信息!') + else: + logger.error(f'为群组: {group_id} 清除自定义欢迎信息失败, error info: {msg_set_result.info}') + await welcome_msg_clear.finish(f'为本群组清除自定义欢迎信息失败了QAQ, 请稍后再试或联系管理员处理') + + +# 注册事件响应器, 新增群成员 +group_increase = on_notice(priority=100, rule=OmegaRules.has_group_command_permission()) + + +@group_increase.handle() +async def handle_group_increase(bot: Bot, event: GroupIncreaseNoticeEvent, state: T_State): + user_id = event.user_id + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + # 获取自定义欢迎信息 + welcome_msg_result = await group.setting_get(setting_name=SETTING_NAME) + if welcome_msg_result.success(): + main, second, extra = welcome_msg_result.result + if extra: + msg = extra + else: + msg = DEFAULT_WELCOME_MSG + else: + msg = DEFAULT_WELCOME_MSG + + # 发送欢迎消息 + at_seg = MessageSegment.at(user_id=user_id) + await bot.send(event=event, message=Message(at_seg).append(msg)) + logger.info(f'群组: {group_id}, 有新用户: {user_id} 进群') + + +__all__ = [ + 'WelcomeMsg', + 'group_increase' +] diff --git a/omega_miya/utils/Omega_auto_manager/__init__.py b/omega_miya/plugins/Omega_auto_manager/invite_request.py similarity index 66% rename from omega_miya/utils/Omega_auto_manager/__init__.py rename to omega_miya/plugins/Omega_auto_manager/invite_request.py index 0a0812a5..a10a655c 100644 --- a/omega_miya/utils/Omega_auto_manager/__init__.py +++ b/omega_miya/plugins/Omega_auto_manager/invite_request.py @@ -1,9 +1,19 @@ +""" +@Author : Ailitonia +@Date : 2021/06/12 0:28 +@FileName : invite_request.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + from nonebot import on_request, on_notice, logger from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.message import MessageSegment, Message from nonebot.adapters.cqhttp.event import FriendRequestEvent, GroupRequestEvent, GroupIncreaseNoticeEvent -from omega_miya.utils.Omega_Base import DBGroup +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup # 注册事件响应器 add_and_invite_request = on_request(priority=100) @@ -39,19 +49,6 @@ async def handle_group_invite(bot: Bot, event: GroupRequestEvent, state: T_State logger.info(f'已处理群组请求, 被用户: {user_id} 邀请加入群组: {group_id}.') -# 注册事件响应器, 新增群成员 -group_increase = on_notice(priority=100) - - -@group_increase.handle() -async def handle_group_increase(bot: Bot, event: GroupIncreaseNoticeEvent, state: T_State): - user_id = event.user_id - group_id = event.group_id - detail_type = event.notice_type - group_c_permission_res = await DBGroup(group_id=group_id).permission_command() - if detail_type == 'group_increase' and group_c_permission_res.result == 1: - # 发送欢迎消息 - at_seg = MessageSegment.at(user_id=user_id) - msg = f'{at_seg}欢迎新朋友~\n进群请先看群公告~\n一起愉快地聊天吧!' - await bot.send(event=event, message=Message(msg)) - logger.info(f'群组: {group_id}, 有新用户: {user_id} 进群') +__all__ = [ + 'add_and_invite_request' +] diff --git a/omega_miya/plugins/Omega_email/__init__.py b/omega_miya/plugins/Omega_email/__init__.py index 6dd20fc9..5ef0a42c 100644 --- a/omega_miya/plugins/Omega_email/__init__.py +++ b/omega_miya/plugins/Omega_email/__init__.py @@ -1,4 +1,5 @@ import re +import pathlib from nonebot import MatcherGroup, export, logger from nonebot.rule import to_me from nonebot.permission import SUPERUSER @@ -7,7 +8,7 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent from nonebot.adapters.cqhttp.permission import GROUP -from omega_miya.utils.Omega_Base import DBEmailBox, DBGroup +from omega_miya.utils.Omega_Base import DBEmailBox, DBBot, DBBotGroup from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state from omega_miya.utils.text_to_img import text_to_img from .utils import check_mailbox, get_unseen_mail_info, encrypt_password, decrypt_password @@ -126,13 +127,14 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat async def handle_admin_mail_bind(bot: Bot, event: GroupMessageEvent, state: T_State): mailbox_list = state['mailbox_list'] email_address = state['email_address'] + self_bot = DBBot(self_qq=int(bot.self_id)) if email_address not in mailbox_list: logger.warning(f'Group:{event.group_id}/User:{event.user_id} 绑定邮箱: {email_address} 失败, 不在可绑定邮箱中的邮箱') await admin_mail_bind.finish('该邮箱不在可绑定邮箱中!') group_id = event.group_id - res = await DBGroup(group_id=group_id).mailbox_add(mailbox=DBEmailBox(address=email_address)) + res = await DBBotGroup(group_id=group_id, self_bot=self_bot).mailbox_add(mailbox=DBEmailBox(address=email_address)) if res.success(): logger.info(f'Group:{event.group_id}/User:{event.user_id} 绑定邮箱: {email_address} 成功') @@ -152,7 +154,8 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat await mail_receive.finish('该命令不支持参数QAQ') group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) res = await group.mailbox_clear() if res.success(): @@ -185,7 +188,8 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat await mail_receive.finish('该命令不支持参数QAQ') group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) group_bind_mailbox = await group.mailbox_list() if not group_bind_mailbox.success() or not group_bind_mailbox.result: logger.info(f'{group_id} 收邮件失败: 没有绑定的邮箱') @@ -230,7 +234,8 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat text_img_result = await text_to_img(text=msg) if text_img_result.error: raise Exception(f'Text to img failed, {text_img_result.info}') - img_seg = MessageSegment.image(f'file:///{text_img_result.result}') + file_url = pathlib.Path(text_img_result.result).as_uri() + img_seg = MessageSegment.image(file=file_url) await mail_receive.send(img_seg) except Exception as e: logger.error(f'发送邮件信息失败, {repr(e)}') diff --git a/omega_miya/plugins/Omega_email/imap.py b/omega_miya/plugins/Omega_email/imap.py index 934f14bc..ac8df838 100644 --- a/omega_miya/plugins/Omega_email/imap.py +++ b/omega_miya/plugins/Omega_email/imap.py @@ -47,7 +47,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_mail_info(self, charset, *criteria) -> List[Email]: self.__mail.login(self.__address, self.__password) - if self.__address.endswith('163.com'): + if self.__address.endswith('@163.com'): # 添加163邮箱 IMAP ID 验证 imaplib.Commands['ID'] = ('AUTH',) args = ("name", "omega", "contact", "omega_miya@163.com", "version", "1.0.2", "vendor", "pyimaplibclient") diff --git a/omega_miya/plugins/Omega_manage/__init__.py b/omega_miya/plugins/Omega_manager/__init__.py similarity index 81% rename from omega_miya/plugins/Omega_manage/__init__.py rename to omega_miya/plugins/Omega_manager/__init__.py index 98b5c2d2..3e487eb0 100644 --- a/omega_miya/plugins/Omega_manage/__init__.py +++ b/omega_miya/plugins/Omega_manager/__init__.py @@ -14,7 +14,7 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER, PRIVATE_FRIEND -from omega_miya.utils.Omega_Base import DBGroup, DBUser, DBAuth, DBFriend, Result +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBUser, DBAuth, DBFriend, Result from omega_miya.utils.Omega_plugin_utils import init_export from .sys_background_scheduled import scheduler @@ -141,8 +141,9 @@ async def handle_sub_command(bot: Bot, event: PrivateMessageEvent, state: T_Stat async def friend_init(bot: Bot, event: PrivateMessageEvent, state: T_State) -> Result.TextResult: user_id = event.user_id + self_bot = DBBot(self_qq=int(bot.self_id)) # 调用api获取好友列表 - friends_list = await bot.call_api('get_friend_list') + friends_list = await bot.get_friend_list() actual_friend_list = [int(x.get('user_id')) for x in friends_list] if user_id not in actual_friend_list: return Result.TextResult(error=True, info='Not in friends list', result='错误, 不在好友列表中') @@ -151,7 +152,7 @@ async def friend_init(bot: Bot, event: PrivateMessageEvent, state: T_State) -> R nickname = user_info.get('nickname') remark = user_info.get('remark') - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) # 更新用户表 add_user_result = await friend.add(nickname=nickname) @@ -159,7 +160,7 @@ async def friend_init(bot: Bot, event: PrivateMessageEvent, state: T_State) -> R return Result.TextResult(error=True, info=add_user_result.info, result='错误, 请联系管理员处理') # 初始化好友authnode - await init_user_auth_node(user_id=user_id) + await init_user_auth_node(user_id=user_id, self_bot=self_bot) set_friend_result = await friend.set_friend(nickname=nickname, remark=remark, private_permissions=1) if set_friend_result.success(): @@ -170,11 +171,11 @@ async def friend_init(bot: Bot, event: PrivateMessageEvent, state: T_State) -> R async def friend_private_enable(bot: Bot, event: PrivateMessageEvent, state: T_State) -> Result.TextResult: user_id = event.user_id - + self_bot = DBBot(self_qq=int(bot.self_id)) # 初始化好友authnode - await init_user_auth_node(user_id=user_id) + await init_user_auth_node(user_id=user_id, self_bot=self_bot) - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) result = await friend.set_private_permission(private_permissions=1) if result.success(): return Result.TextResult(error=False, info='Success', result='成功, 已启用私聊功能, 权限节点已设置为默认值') @@ -184,11 +185,11 @@ async def friend_private_enable(bot: Bot, event: PrivateMessageEvent, state: T_S async def friend_private_disable(bot: Bot, event: PrivateMessageEvent, state: T_State) -> Result.TextResult: user_id = event.user_id - + self_bot = DBBot(self_qq=int(bot.self_id)) # 初始化好友authnode - await init_user_auth_node(user_id=user_id) + await init_user_auth_node(user_id=user_id, self_bot=self_bot) - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) result = await friend.set_private_permission(private_permissions=0) if result.success(): return Result.TextResult(error=False, info='Success', result='成功, 已禁用私聊功能, 权限节点已重置为默认值') @@ -198,29 +199,35 @@ async def friend_private_disable(bot: Bot, event: PrivateMessageEvent, state: T_ async def group_init(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) # 调用api获取群信息 - group_info = await bot.call_api(api='get_group_info', group_id=group_id) + group_info = await bot.get_group_info(group_id=group_id) group_name = group_info['group_name'] - group = DBGroup(group_id=group_id) + group_memo = group_info.get('group_memo') + group = DBBotGroup(group_id=group_id, self_bot=self_bot) # 添加并初始化群信息 _result = await group.add(name=group_name) if not _result.success(): return Result.IntResult(True, _result.info, -1) + _result = await group.set_bot_group(group_memo=group_memo) + if not _result.success(): + return Result.IntResult(True, _result.info, -1) + _result = await group.permission_set(notice=1, command=1, level=10) if not _result.success(): return Result.IntResult(True, _result.info, -1) # 初始化群组authnode - await init_group_auth_node(group_id=group_id) + await init_group_auth_node(group_id=group_id, self_bot=self_bot) _result = await group.member_clear() if not _result.success(): return Result.IntResult(True, _result.info, -1) # 添加用户 - group_member_list = await bot.call_api(api='get_group_member_list', group_id=group_id) + group_member_list = await bot.get_group_member_list(group_id=group_id) failed_user = [] for user_info in group_member_list: # 用户信息 @@ -234,13 +241,13 @@ async def group_init(bot: Bot, event: GroupMessageEvent, state: T_State) -> Resu _result = await _user.add(nickname=user_nickname) if not _result.success(): failed_user.append(_user.qq) - logger.warning(f'User: {user_qq}, {_result.info}') + logger.warning(f'Add group user: {user_qq}, {_result.info}') continue _result = await group.member_add(user=_user, user_group_nickname=user_group_nickmane) if not _result.success(): failed_user.append(_user.qq) - logger.warning(f'User: {user_qq}, {_result.info}') + logger.warning(f'Upgrade group user: {user_qq}, {_result.info}') await group.init_member_status() @@ -249,30 +256,30 @@ async def group_init(bot: Bot, event: GroupMessageEvent, state: T_State) -> Resu async def group_upgrade(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) # 调用api获取群信息 - group_info = await bot.call_api(api='get_group_info', group_id=group_id) + group_info = await bot.get_group_info(group_id=group_id) group_name = group_info['group_name'] - group = DBGroup(group_id=group_id) + group_memo = group_info.get('group_memo') + group = DBBotGroup(group_id=group_id, self_bot=self_bot) # 更新群信息 _result = await group.add(name=group_name) if not _result.success(): return Result.IntResult(True, _result.info, -1) + _result = await group.set_bot_group(group_memo=group_memo) + if not _result.success(): + return Result.IntResult(True, _result.info, -1) + # 更新用户 - group_member_list = await bot.call_api(api='get_group_member_list', group_id=group_id) + group_member_list = await bot.get_group_member_list(group_id=group_id) failed_user = [] # 首先清除数据库中退群成员 - exist_member_list = [] - for user_info in group_member_list: - user_qq = user_info['user_id'] - exist_member_list.append(int(user_qq)) - - db_member_list = [] + exist_member_list = [int(x.get('user_id')) for x in group_member_list] member_res = await group.member_list() - for user_id, nickname in member_res.result: - db_member_list.append(user_id) + db_member_list = [user_id for user_id, nickname in member_res.result] del_member_list = list(set(db_member_list).difference(set(exist_member_list))) for user_id in del_member_list: @@ -281,9 +288,9 @@ async def group_upgrade(bot: Bot, event: GroupMessageEvent, state: T_State) -> R # 更新群成员 for user_info in group_member_list: # 用户信息 - user_qq = user_info['user_id'] - user_nickname = user_info['nickname'] - user_group_nickmane = user_info['card'] + user_qq = user_info.get('user_id') + user_nickname = user_info.get('nickname') + user_group_nickmane = user_info.get('card') if not user_group_nickmane: user_group_nickmane = user_nickname @@ -291,13 +298,13 @@ async def group_upgrade(bot: Bot, event: GroupMessageEvent, state: T_State) -> R _result = await _user.add(nickname=user_nickname) if not _result.success(): failed_user.append(_user.qq) - logger.warning(f'User: {user_qq}, {_result.info}') + logger.warning(f'Add group user: {user_qq}, {_result.info}') continue _result = await group.member_add(user=_user, user_group_nickname=user_group_nickmane) if not _result.success(): failed_user.append(_user.qq) - logger.warning(f'User: {user_qq}, {_result.info}') + logger.warning(f'Upgrade group user: {user_qq}, {_result.info}') await group.init_member_status() @@ -306,7 +313,8 @@ async def group_upgrade(bot: Bot, event: GroupMessageEvent, state: T_State) -> R async def set_group_notice(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) permission_res = await group.permission_info() if permission_res.error: return Result.IntResult(True, permission_res.info, -1) @@ -325,7 +333,8 @@ async def set_group_notice(bot: Bot, event: GroupMessageEvent, state: T_State) - async def set_group_command(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) permission_res = await group.permission_info() if permission_res.error: return Result.IntResult(True, permission_res.info, -1) @@ -333,8 +342,12 @@ async def set_group_command(bot: Bot, event: GroupMessageEvent, state: T_State) group_notice, _command, group_level = permission_res.result if state['sub_arg'] == 'on': + # 初始化群组authnode + await init_group_auth_node(group_id=group_id, self_bot=self_bot) result = await group.permission_set(notice=group_notice, command=1, level=group_level) elif state['sub_arg'] == 'off': + # 初始化群组authnode + await init_group_auth_node(group_id=group_id, self_bot=self_bot) result = await group.permission_set(notice=group_notice, command=0, level=group_level) else: result = Result.IntResult(True, 'Missing parameters or Illegal parameter', -1) @@ -344,7 +357,8 @@ async def set_group_command(bot: Bot, event: GroupMessageEvent, state: T_State) async def set_group_level(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) permission_res = await group.permission_info() if permission_res.error: return Result.IntResult(True, permission_res.info, -1) @@ -362,7 +376,8 @@ async def set_group_level(bot: Bot, event: GroupMessageEvent, state: T_State) -> async def show_group_permission(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.TextResult: group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) permission_res = await group.permission_info() if permission_res.error: return Result.TextResult(True, permission_res.info, '') @@ -376,14 +391,13 @@ async def show_group_permission(bot: Bot, event: GroupMessageEvent, state: T_Sta async def reset_group_permission(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id - group = DBGroup(group_id=group_id) - + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) result = await group.permission_reset() - return result -async def init_group_auth_node(group_id: int): +async def init_group_auth_node(group_id: int, self_bot: DBBot): """ 为群组配置权限节点默认值 """ @@ -398,17 +412,18 @@ class AuthNode: AuthNode(node='Omega_help.skip_cd', allow_tag=1, deny_tag=0, auth_info='默认规则: help免cd'), AuthNode(node='nhentai.basic', allow_tag=0, deny_tag=1, auth_info='默认规则: 禁用nhentai'), AuthNode(node='setu.setu', allow_tag=0, deny_tag=1, auth_info='默认规则: 禁用setu'), + AuthNode(node='setu.allow_r18', allow_tag=0, deny_tag=1, auth_info='默认规则: 禁用setu r18'), AuthNode(node='pixiv.allow_r18', allow_tag=0, deny_tag=1, auth_info='默认规则: 禁用pivix r18') ] for auth_node in default_auth_nodes: - auth = DBAuth(auth_id=group_id, auth_type='group', auth_node=auth_node.node) + auth = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=auth_node.node) res = await auth.set(allow_tag=auth_node.allow_tag, deny_tag=auth_node.deny_tag, auth_info=auth_node.auth_info) if res.error: logger.opt(colors=True).error(f'配置默认权限失败, {auth_node.node}/{group_id}, error: {res.info}') -async def init_user_auth_node(user_id: int): +async def init_user_auth_node(user_id: int, self_bot: DBBot): """ 为好友配置权限节点默认值 """ @@ -424,7 +439,7 @@ class AuthNode: ] for auth_node in default_auth_nodes: - auth = DBAuth(auth_id=user_id, auth_type='user', auth_node=auth_node.node) + auth = DBAuth(self_bot=self_bot, auth_id=user_id, auth_type='user', auth_node=auth_node.node) res = await auth.set(allow_tag=auth_node.allow_tag, deny_tag=auth_node.deny_tag, auth_info=auth_node.auth_info) if res.error: logger.opt(colors=True).error(f'配置默认权限失败, {auth_node.node}/{user_id}, error: {res.info}') diff --git a/omega_miya/plugins/Omega_manage/sys_background_scheduled.py b/omega_miya/plugins/Omega_manager/sys_background_scheduled.py similarity index 70% rename from omega_miya/plugins/Omega_manage/sys_background_scheduled.py rename to omega_miya/plugins/Omega_manager/sys_background_scheduled.py index a404c0dd..18026515 100644 --- a/omega_miya/plugins/Omega_manage/sys_background_scheduled.py +++ b/omega_miya/plugins/Omega_manager/sys_background_scheduled.py @@ -3,10 +3,9 @@ """ import nonebot from nonebot import logger, require -from omega_miya.utils.Omega_Base import DBGroup, DBUser, DBFriend, DBStatus, DBCoolDownEvent, DBTable +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBUser, DBFriend, DBStatus, DBCoolDownEvent from omega_miya.utils.Omega_plugin_utils import HttpFetcher - global_config = nonebot.get_driver().config ENABLE_PROXY = global_config.enable_proxy PROXY_CHECK_URL = global_config.proxy_check_url @@ -32,29 +31,30 @@ # start_date=None, # end_date=None, # timezone=None, - id='refresh_group_info', + id='refresh_groups_info', coalesce=True, misfire_grace_time=300 ) -async def refresh_group_info(): - logger.debug('refresh_group_info: Start task') +async def refresh_groups_info(): + logger.opt(colors=True).info('Refresh groups info | Started all bots groups refreshing tasks...') from nonebot import get_bots for bot_id, bot in get_bots().items(): + self_bot = DBBot(self_qq=int(bot.self_id)) group_list = await bot.call_api('get_group_list') # 首先获取所有群组列表 禁用不在的群组 - t = DBTable(table_name='Group') - exist_group_result = await t.list_col('group_id') + exist_group_result = await DBBotGroup.list_exist_bot_groups(self_bot=self_bot) exist_group_list = [int(x) for x in exist_group_result.result] actual_group_list = [int(x.get('group_id')) for x in group_list] disable_group_list = list(set(exist_group_list).difference(set(actual_group_list))) for group in disable_group_list: - disable_result = await DBGroup(group_id=group).permission_set(notice=-1, command=-1, level=-1) + disable_result = await DBBotGroup(group_id=group, self_bot=self_bot).\ + permission_set(notice=-1, command=-1, level=-1) if disable_result.error: - logger.warning(f'Disable expire group {group} failed, {disable_result.info}') + logger.warning(f'Refresh groups info | Disable expire group {group} failed, {disable_result.info}') # 执行群组信息更新 for group in group_list: @@ -62,24 +62,26 @@ async def refresh_group_info(): # 调用api获取群信息 group_info = await bot.call_api(api='get_group_info', group_id=group_id) group_name = group_info['group_name'] - group = DBGroup(group_id=group_id) + group_memo = group_info.get('group_memo') + group = DBBotGroup(group_id=group_id, self_bot=self_bot) # 更新群信息 - await group.add(name=group_name) + add_group_res = await group.add(name=group_name) + if add_group_res.error: + logger.error(f'Refresh groups info | Add group {group_id} failed, {add_group_res.info}') + continue + set_bot_group_res = await group.set_bot_group(group_memo=group_memo) + if set_bot_group_res.error: + logger.error(f'Refresh groups info | Add group {group_id} failed, {set_bot_group_res.info}') + continue # 更新用户 group_member_list = await bot.call_api(api='get_group_member_list', group_id=group_id) # 首先清除数据库中退群成员 - exist_member_list = [] - for user_info in group_member_list: - user_qq = user_info['user_id'] - exist_member_list.append(int(user_qq)) - - db_member_list = [] + exist_member_list = [int(x.get('user_id')) for x in group_member_list] member_res = await group.member_list() - for user_id, nickname in member_res.result: - db_member_list.append(user_id) + db_member_list = [user_id for user_id, nickname in member_res.result] del_member_list = list(set(db_member_list).difference(set(exist_member_list))) for user_id in del_member_list: @@ -88,23 +90,23 @@ async def refresh_group_info(): # 更新群成员 for user_info in group_member_list: # 用户信息 - user_qq = user_info['user_id'] - user_nickname = user_info['nickname'] - user_group_nickmane = user_info['card'] + user_qq = user_info.get('user_id') + user_nickname = user_info.get('nickname') + user_group_nickmane = user_info.get('card') if not user_group_nickmane: user_group_nickmane = user_nickname _user = DBUser(user_id=user_qq) _result = await _user.add(nickname=user_nickname) if not _result.success(): - logger.warning(f'Refresh group info, User: {user_qq}, {_result.info}') + logger.warning(f'Refresh groups info | Add group user: {user_qq}, {_result.info}') continue _result = await group.member_add(user=_user, user_group_nickname=user_group_nickmane) if not _result.success(): - logger.warning(f'Refresh group info, User: {user_qq}, {_result.info}') + logger.warning(f'Refresh groups info | Upgrade group user: {user_qq}, {_result.info}') await group.init_member_status() - logger.info(f'Refresh group info completed, Bot: {bot_id}, Group: {group_id}') - logger.debug('refresh_group_info: Task finish') + logger.info(f'Refresh groups info | Task completed, Bot: {bot_id}, Group: {group_id}') + logger.opt(colors=True).info('Refresh groups info | All tasks completed') # 创建自动更新好友信息的定时任务 @@ -126,15 +128,16 @@ async def refresh_group_info(): misfire_grace_time=120 ) async def refresh_friends_info(): - logger.debug('refresh_friends_info: Start task') + logger.opt(colors=True).info('Refresh friends info | Started all bots friends refreshing tasks...') from nonebot import get_bots for bot_id, bot in get_bots().items(): + self_bot = DBBot(self_qq=int(bot.self_id)) friends_list = await bot.call_api('get_friend_list') # 首先清除非好友 - exist_friend_result = await DBFriend.list_exist_friends() + exist_friend_result = await DBFriend.list_exist_friends(self_bot=self_bot) if exist_friend_result.error: - logger.error(f'Refresh friends info failed, get exist friend list failed: {exist_friend_result.info}') + logger.error(f'Refresh friends info | Getting exist friend list failed: {exist_friend_result.info}') return exist_friend_list = exist_friend_result.result @@ -142,9 +145,9 @@ async def refresh_friends_info(): del_member_list = list(set(exist_friend_list).difference(set(actual_friend_list))) for user in del_member_list: - del_result = await DBFriend(user_id=user).del_friend() + del_result = await DBFriend(user_id=user, self_bot=self_bot).del_friend() if del_result.error: - logger.warning(f'Del expire friend user {user} failed, {del_result.info}') + logger.warning(f'Refresh friends info | Del expire friend user {user} failed, {del_result.info}') # 更新好友信息 for friend in friends_list: @@ -152,22 +155,21 @@ async def refresh_friends_info(): nickname = friend.get('nickname') remark = friend.get('remark') - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) # 更新用户表 add_user_result = await friend.add(nickname=nickname) if add_user_result.error: - logger.warning(f'Add friend user {user_id} failed, {add_user_result.info}') + logger.error(f'Refresh friends info | Add user {user_id} failed, {add_user_result.info}') continue # 更新好友表 add_friend_result = await friend.set_friend(nickname=nickname, remark=remark) if add_friend_result.error: - logger.warning(f'Add friend user {user_id} failed, {add_user_result.info}') - - logger.debug(f'Refresh friends info, upgrade friend user {user_id} info') + logger.error(f'Refresh friends info | Add friend user {user_id} failed, {add_user_result.info}') + logger.debug(f'Refresh friends info | Upgrade friend user {user_id} info') - logger.info(f'Refresh friends info completed, Bot: {bot_id}') - logger.debug('refresh_friends_info: Task finish') + logger.info(f'Refresh friends info | Task completed, Bot: {bot_id}') + logger.opt(colors=True).info('Refresh friends info | All tasks completed') # 创建用于刷新冷却事件的定时任务 @@ -178,9 +180,9 @@ async def refresh_friends_info(): # day='*/1', # week=None, # day_of_week=None, - # hour='*/8', + hour='*/12', # minute='*/1', - second='*/20', + # second='*/20', # start_date=None, # end_date=None, # timezone=None, @@ -190,7 +192,7 @@ async def refresh_friends_info(): ) async def cool_down_refresh(): await DBCoolDownEvent.clear_time_out_event() - logger.debug('cool_down_refresh: cleaning time out event') + logger.opt(colors=True).info('Cool down refresh | Cleaned all expired event') # 创建用于检查代理可用性的状态的定时任务 diff --git a/omega_miya/plugins/Omega_recaller/__init__.py b/omega_miya/plugins/Omega_recaller/__init__.py new file mode 100644 index 00000000..edef843d --- /dev/null +++ b/omega_miya/plugins/Omega_recaller/__init__.py @@ -0,0 +1,248 @@ +""" +@Author : Ailitonia +@Date : 2021/07/17 22:36 +@FileName : Omega_recaller.py +@Project : nonebot2_miya +@Description : 自助撤回群内消息 需bot为管理员 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from typing import Dict +from datetime import datetime +from nonebot import export, logger +from nonebot.plugin import CommandGroup +from nonebot.typing import T_State +from nonebot.exception import FinishedException +from nonebot.permission import SUPERUSER +from nonebot.adapters.cqhttp.permission import GROUP, GROUP_OWNER, GROUP_ADMIN +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import GroupMessageEvent +from nonebot.adapters.cqhttp.message import Message, MessageSegment +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBAuth, DBHistory +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PermissionChecker + + +# Custom plugin usage text +__plugin_raw_name__ = __name__.split('.')[-1] +__plugin_name__ = '自助撤回' +__plugin_usage__ = r'''【自助撤回】 +让非管理员自助撤回群消息 +Bot得是管理员才行 + +**Permission** +AuthNode + +**AuthNode** +basic + +**Usage** +回复需撤回的消息 +/撤回 + +**GroupAdmin and SuperUser Only** +/启用撤回 [@用户] +/禁用撤回 [@用户] +''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + +# 存放bot在群组的身份 +BOT_ROLE: Dict[int, str] = {} + +# 存放bot群组信息过期时间 +BOT_ROLE_EXPIRED: Dict[int, datetime] = {} + + +# 注册事件响应器 +SelfHelpRecall = CommandGroup( + 'SelfHelpRecall', + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='search_anime', + command=True, + level=10, + auth_node='basic'), + permission=SUPERUSER | GROUP, + priority=10, + block=True +) + + +recall = SelfHelpRecall.command('recall', aliases={'撤回'}) + + +@recall.handle() +async def handle_super_recall_self_msg(bot: Bot, event: GroupMessageEvent, state: T_State): + # 特别处理管理员撤回bot发送的消息 + if event.reply and str(event.user_id) in bot.config.superusers: + if event.reply.sender.user_id == event.self_id: + recall_msg_id = event.reply.message_id + try: + await bot.delete_msg(message_id=recall_msg_id) + logger.info(f'Self-help Recall | 管理员 {event.group_id}/{event.user_id} ' + f'撤回了Bot消息: {recall_msg_id}, "{event.reply.message}"') + await recall.finish() + except FinishedException: + raise FinishedException + except Exception as e: + logger.error(f'Self-help Recall | 管理员 {event.group_id}/{event.user_id} ' + f'撤回Bot消息失败, error: {repr(e)}') + msg = f'{MessageSegment.at(user_id=event.user_id)}撤回消息部分或全部失败了QAQ' + await recall.finish(Message(msg)) + + +@recall.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + if event.sender.role in ['owner', 'admin']: + await recall.finish('您已经是群管理了, 请您自行去撤回消息OvO') + + # 用户权限检查 + auth_check_result = await PermissionChecker(self_bot=DBBot(self_qq=event.self_id)).check_auth_node( + auth_id=event.group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic.{event.user_id}') + if auth_check_result != 1: + await recall.finish(Message(f'{MessageSegment.at(user_id=event.user_id)}你没有撤回消息的权限QAQ')) + + global BOT_ROLE + global BOT_ROLE_EXPIRED + # 判断bot身份和过期时间 + bot_role = BOT_ROLE.get(event.group_id) + bot_role_expired = BOT_ROLE_EXPIRED.get(event.group_id) + if not bot_role_expired: + bot_role_expired = datetime.now() + BOT_ROLE_EXPIRED.update({event.group_id: bot_role_expired}) + # 默认过期时间为 21600 秒 (6 小时) + is_role_expired = (datetime.now() - bot_role_expired).total_seconds() > 21600 + if is_role_expired or not bot_role: + bot_role = (await bot.get_group_member_info(group_id=event.group_id, user_id=event.self_id)).get('role') + BOT_ROLE.update({event.group_id: bot_role}) + + if bot_role not in ['owner', 'admin']: + await recall.finish('Bot非群管理员, 无法执行撤回操作QAQ') + + error_tag: bool = False + # 提取引用消息 + if event.reply: + # 同时撤回被引用的消息 + recall_msg_id = event.reply.message_id + try: + await bot.delete_msg(message_id=recall_msg_id) + logger.info( + f'Self-help Recall | {event.group_id}/{event.user_id} 撤回消息: {recall_msg_id}, "{event.reply.message}"') + except Exception as e: + error_tag = True + logger.error(f'Self-help Recall | {event.group_id}/{event.user_id} 撤回引用消息失败, error: {repr(e)}') + + # 同时撤回和当前执行撤回人的消息 + command_msg_id = event.message_id + try: + await bot.delete_msg(message_id=command_msg_id) + logger.debug(f'Self-help Recall | {event.group_id}/{event.user_id} 撤回执行消息: {command_msg_id}') + except Exception as e: + error_tag = True + logger.error(f'Self-help Recall | {event.group_id}/{event.user_id} 撤回当前执行消息失败, error: {repr(e)}') + + history_result = await DBHistory( + time=event.time, self_id=event.self_id, post_type='Self-help Recall', detail_type='member-recall').add( + sub_type=f'error:{error_tag}', event_id=event.message_id, group_id=event.group_id, user_id=event.user_id, + user_name=event.sender.nickname, + raw_data=f'Operator: {event.group_id}/{event.user_id}; RecalledMsg, user_id: {event.reply.sender.user_id}, ' + f'msg_id: {recall_msg_id}, {event.reply.message}', + msg_data=f'群: {event.group_id}, 用户: {event.user_id}/{event.sender.card}/{event.sender.nickname}, ' + f'撤回了用户 {event.reply.sender.user_id}/{event.reply.sender.nickname} 的一条消息: {recall_msg_id}, ' + f'被撤回消息内容: {event.reply.message}' + ) + if history_result.error: + logger.error(f'Self-help Recall | 记录撤回历史失败, error: {history_result.info}') + + if error_tag: + msg = f'{MessageSegment.at(user_id=event.user_id)}撤回消息部分或全部失败了QAQ' + await recall.finish(Message(msg)) + else: + msg = f'{MessageSegment.at(user_id=event.user_id)}你撤回了' \ + f'{MessageSegment.at(user_id=event.reply.sender.user_id)}的一条消息' + await recall.finish(Message(msg)) + else: + await recall.finish('没有引用需要撤回的消息! 请回复需要撤回的消息后发送“/撤回”') + + +recall_allow = SelfHelpRecall.command( + 'recall_allow', aliases={'启用撤回'}, permission=SUPERUSER | GROUP_OWNER | GROUP_ADMIN) + + +@recall_allow.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().split() + if not args: + pass + else: + await recall_allow.finish('参数错误QAQ, 请在 “/启用撤回” 命令后直接@对应用户') + + # 处理@人 qq在at别人时后面会自动加空格 + if len(event.message) in [1, 2]: + if event.message[0].type == 'at': + at_qq = event.message[0].data.get('qq') + if at_qq: + self_bot = DBBot(self_qq=event.self_id) + group = DBBotGroup(group_id=event.group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + logger.error(f'Self-help Recall | 启用用户撤回失败, 数据库没有对应群组: {event.group_id}') + await recall_allow.finish('发生了意外的错误QAQ, 请联系管理员处理') + + auth_node = DBAuth(self_bot=self_bot, auth_id=event.group_id, auth_type='group', + auth_node=f'{__plugin_raw_name__}.basic.{at_qq}') + result = await auth_node.set(allow_tag=1, deny_tag=0, auth_info='启用自助撤回') + if result.success(): + logger.info(f'Self-help Recall | {event.group_id}/{event.user_id} 已启用用户 {at_qq} 撤回权限') + await recall_allow.finish(f'已启用用户{at_qq}撤回权限') + else: + logger.error(f'Self-help Recall | {event.group_id}/{event.user_id} 启用用户 {at_qq} 撤回权限失败, ' + f'error: {result.info}') + await recall_allow.finish(f'启用用户撤回权限失败QAQ, 请联系管理员处理') + + await recall_allow.finish('没有指定用户QAQ, 请在 “/启用撤回” 命令后直接@对应用户') + + +recall_deny = SelfHelpRecall.command( + 'recall_deny', aliases={'禁用撤回'}, permission=SUPERUSER | GROUP_OWNER | GROUP_ADMIN) + + +@recall_deny.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().split() + if not args: + pass + else: + await recall_deny.finish('参数错误QAQ, 请在 “/禁用撤回” 命令后直接@对应用户') + + # 处理@人 qq在at别人时后面会自动加空格 + if len(event.message) in [1, 2]: + if event.message[0].type == 'at': + at_qq = event.message[0].data.get('qq') + if at_qq: + self_bot = DBBot(self_qq=event.self_id) + group = DBBotGroup(group_id=event.group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + logger.error(f'Self-help Recall | 禁用用户撤回失败, 数据库没有对应群组: {event.group_id}') + await recall_deny.finish('发生了意外的错误QAQ, 请联系管理员处理') + + auth_node = DBAuth(self_bot=self_bot, auth_id=event.group_id, auth_type='group', + auth_node=f'{__plugin_raw_name__}.basic.{at_qq}') + result = await auth_node.set(allow_tag=0, deny_tag=1, auth_info='禁用自助撤回') + if result.success(): + logger.info(f'Self-help Recall | {event.group_id}/{event.user_id} 已禁用用户 {at_qq} 撤回权限') + await recall_deny.finish(f'已禁用用户{at_qq}撤回权限') + else: + logger.error(f'Self-help Recall | {event.group_id}/{event.user_id} 禁用用户 {at_qq} 撤回权限失败, ' + f'error: {result.info}') + await recall_deny.finish(f'禁用用户撤回权限失败QAQ, 请联系管理员处理') + + await recall_deny.finish('没有指定用户QAQ, 请在 “/禁用撤回” 命令后直接@对应用户') diff --git a/omega_miya/plugins/Omega_sign_in/__init__.py b/omega_miya/plugins/Omega_sign_in/__init__.py new file mode 100644 index 00000000..fb0b6044 --- /dev/null +++ b/omega_miya/plugins/Omega_sign_in/__init__.py @@ -0,0 +1,134 @@ +""" +@Author : Ailitonia +@Date : 2021/07/17 1:29 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : 轻量化签到插件 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from nonebot import CommandGroup, logger, export, get_driver +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import GroupMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state +from omega_miya.utils.Omega_Base import DBUser +from .config import Config + + +__global_config = get_driver().config +plugin_config = Config(**__global_config.dict()) +FAVORABILITY_ALIAS = plugin_config.favorability_alias +ENERGY_ALIAS = plugin_config.energy_alias +CURRENCY_ALIAS = plugin_config.currency_alias + + +class SignInException(Exception): + pass + + +class DuplicateException(SignInException): + pass + + +class FailedException(SignInException): + pass + + +# Custom plugin usage text +__plugin_name__ = '签到' +__plugin_usage__ = r'''【Omega 签到插件】 +轻量化签到插件 +好感度系统基础支持 +仅限群聊使用 + +**Permission** +Command & Lv.10 +or AuthNode + +**AuthNode** +basic + +**Usage** +/签到''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +SignIn = CommandGroup( + 'SignIn', + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='sign_in', + command=True, + level=10, + auth_node='basic'), + permission=GROUP, + priority=10, + block=True) + +sign_in = SignIn.command('sign_in', aliases={'签到'}) + + +@sign_in.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + user = DBUser(user_id=event.user_id) + + try: + # 尝试签到 + sign_in_result = await user.sign_in() + if sign_in_result.error: + raise FailedException(f'签到失败, {sign_in_result.info}') + elif sign_in_result.result == 1: + raise DuplicateException('重复签到') + + # 查询连续签到时间 + sign_in_c_d_result = await user.sign_in_continuous_days() + if sign_in_c_d_result.error: + raise FailedException(f'查询连续签到时间失败, {sign_in_c_d_result.info}') + continuous_days = sign_in_c_d_result.result + + # 尝试为用户增加好感度 + if continuous_days < 7: + favorability_inc = 10 + currency_inc = 1 + elif continuous_days < 30: + favorability_inc = 30 + currency_inc = 2 + else: + favorability_inc = 50 + currency_inc = 5 + + favorability_result = await user.favorability_add(favorability=favorability_inc, currency=currency_inc) + if favorability_result.error and favorability_result.info == 'NoResultFound': + favorability_result = await user.favorability_reset(favorability=favorability_inc, currency=currency_inc) + if favorability_result.error: + raise FailedException(f'增加好感度失败, {favorability_result.info}') + + # 获取当前好感度信息 + favorability_status_result = await user.favorability_status() + if favorability_status_result.error: + raise FailedException(f'获取好感度信息失败, {favorability_status_result}') + + status, mood, favorability, energy, currency, response_threshold = favorability_status_result.result + + msg = f'签到成功! {FAVORABILITY_ALIAS}+{favorability_inc}, {CURRENCY_ALIAS}+{currency_inc}!\n\n' \ + f'你已连续签到{continuous_days}天\n' \ + f'当前{FAVORABILITY_ALIAS}: {favorability}\n' \ + f'当前{CURRENCY_ALIAS}: {currency}' + logger.info(f'{event.group_id}/{event.user_id} 签到成功') + await sign_in.finish(msg) + except DuplicateException as e: + logger.info(f'{event.group_id}/{event.user_id} 重复签到, {str(e)}') + await sign_in.finish('你今天已经签到过了, 请明天再来吧~') + except FailedException as e: + logger.error(f'{event.group_id}/{event.user_id} 签到失败, {str(e)}') + await sign_in.finish('签到失败了QAQ, 请稍后再试或联系管理员处理') diff --git a/omega_miya/plugins/Omega_sign_in/config.py b/omega_miya/plugins/Omega_sign_in/config.py new file mode 100644 index 00000000..e218d0cd --- /dev/null +++ b/omega_miya/plugins/Omega_sign_in/config.py @@ -0,0 +1,22 @@ +""" +@Author : Ailitonia +@Date : 2021/07/17 2:04 +@FileName : config.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from pydantic import BaseSettings + + +class Config(BaseSettings): + + # plugin custom config + favorability_alias: str = '好感度' + energy_alias: str = '能量值' + currency_alias: str = '金币' + + class Config: + extra = "ignore" diff --git a/omega_miya/plugins/Omega_skill/__init__.py b/omega_miya/plugins/Omega_skill/__init__.py index 7294bed0..85d420c6 100644 --- a/omega_miya/plugins/Omega_skill/__init__.py +++ b/omega_miya/plugins/Omega_skill/__init__.py @@ -4,7 +4,7 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent from nonebot.adapters.cqhttp.permission import GROUP -from omega_miya.utils.Omega_Base import DBSkill, DBUser, DBTable, Result +from omega_miya.utils.Omega_Base import DBSkill, DBUser, Result from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state # Custom plugin usage text @@ -206,26 +206,25 @@ async def handle_sub_command(bot: Bot, event: GroupMessageEvent, state: T_State) await skill_group_user.finish('没有这个命令哦QAQ') result = await command[sub_command](bot=bot, event=event, state=state) if result.success(): - logger.info(f"Group: {event.group_id}, {sub_command}, Success, {result.info}") + logger.info(f"Group: {event.group_id}, User: {event.user_id}, {sub_command}, Success, {result.info}") if sub_command in need_reply: await skill_group_user.finish(result.result) else: await skill_group_user.finish('Success') else: - logger.error(f"Group: {event.group_id}, {sub_command}, Failed, {result.info}") + logger.error(f"Group: {event.group_id}, User: {event.user_id}, {sub_command}, Failed, {result.info}") await skill_group_user.finish('Failed QAQ') async def skill_list(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.TextResult: - skill_table = DBTable(table_name='Skill') - _res = await skill_table.list_col(col_name='name') - if _res.success(): + skill_res = await DBSkill.list_available_skill() + if skill_res.success(): msg = '目前已有的技能列表如下:' - for skill_name in _res.result: + for skill_name in skill_res.result: msg += f'\n{skill_name}' - result = Result.TextResult(False, _res.info, msg) + result = Result.TextResult(False, skill_res.info, msg) else: - result = Result.TextResult(True, _res.info, '') + result = Result.TextResult(True, skill_res.info, '') return result diff --git a/omega_miya/plugins/Omega_vocation/__init__.py b/omega_miya/plugins/Omega_vacation/__init__.py similarity index 75% rename from omega_miya/plugins/Omega_vocation/__init__.py rename to omega_miya/plugins/Omega_vacation/__init__.py index 522aff84..4906036a 100644 --- a/omega_miya/plugins/Omega_vocation/__init__.py +++ b/omega_miya/plugins/Omega_vacation/__init__.py @@ -5,8 +5,8 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import GroupMessageEvent from nonebot.adapters.cqhttp.permission import GROUP -from omega_miya.utils.Omega_Base import DBSkill, DBUser, DBGroup, DBTable -from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, check_auth_node +from omega_miya.utils.Omega_Base import DBSkill, DBUser, DBBot, DBBotGroup +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PermissionChecker # Custom plugin usage text __plugin_name__ = '请假' @@ -39,18 +39,18 @@ init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) # 注册事件响应器 -vocation = MatcherGroup( +vacation = MatcherGroup( type='message', # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 state=init_permission_state( - name='vocation', + name='vacation', command=True, auth_node='basic'), permission=GROUP, priority=10, block=True) -my_status = vocation.on_command('我的状态') +my_status = vacation.on_command('我的状态') @my_status.handle() @@ -74,7 +74,7 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat # 注册事件响应器 -reset_status = vocation.on_command('重置状态', aliases={'销假'}) +reset_status = vacation.on_command('重置状态', aliases={'销假'}) @reset_status.handle() @@ -91,60 +91,60 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat # 注册事件响应器 -my_vocation = vocation.on_command('我的假期') +my_vacation = vacation.on_command('我的假期') -@my_vocation.handle() +@my_vacation.handle() async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): user_id = event.user_id user = DBUser(user_id=user_id) - result = await user.vocation_status() + result = await user.vacation_status() if result.success(): status, stop_time = result.result if status == 1: msg = f'你的假期将持续到: 【{stop_time}】' else: msg = '你似乎并不在假期中呢~需要现在请假吗?' - logger.info(f"my_vocation: {event.group_id}/{user_id}, Success, {result.info}") + logger.info(f"my_vacation: {event.group_id}/{user_id}, Success, {result.info}") await my_status.finish(msg) else: - logger.error(f"my_vocation: {event.group_id}/{user_id}, Failed, {result.info}") + logger.error(f"my_vacation: {event.group_id}/{user_id}, Failed, {result.info}") await my_status.finish('没有查询到你的假期信息QAQ, 请尝试使用【/重置状态】来解决问题') # 注册事件响应器 -set_vocation = vocation.on_command('请假') +set_vacation = vacation.on_command('请假') # 修改默认参数处理 -@set_vocation.args_parser +@set_vacation.args_parser async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): args = str(event.get_plaintext()).strip().lower().split() if not args: - await set_vocation.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + await set_vacation.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') state[state["_current_key"]] = args[0] if state[state["_current_key"]] == '取消': - await set_vocation.finish('操作已取消') + await set_vacation.finish('操作已取消') -@set_vocation.handle() +@set_vacation.handle() async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): args = str(event.get_plaintext()).strip().lower().split() if not args: pass elif args and len(args) == 1: - state['vocation_time'] = args[0] + state['vacation_time'] = args[0] state['reason'] = None elif args and len(args) == 2: - state['vocation_time'] = args[0] + state['vacation_time'] = args[0] state['reason'] = args[1] else: - await set_vocation.finish('参数错误QAQ') + await set_vacation.finish('参数错误QAQ') -@set_vocation.got('vocation_time', prompt='请输入你想请假的时间: \n仅支持数字+周/天/小时/分/分钟/秒') -async def handle_vocation_time(bot: Bot, event: GroupMessageEvent, state: T_State): - time = state['vocation_time'] +@set_vacation.got('vacation_time', prompt='请输入你想请假的时间: \n仅支持数字+周/天/小时/分/分钟/秒') +async def handle_vacation_time(bot: Bot, event: GroupMessageEvent, state: T_State): + time = state['vacation_time'] add_time = timedelta() if re.match(r'^\d+周$', time): time = int(re.sub(r'周$', '', time)) @@ -162,28 +162,28 @@ async def handle_vocation_time(bot: Bot, event: GroupMessageEvent, state: T_Stat time = int(re.sub(r'秒$', '', time)) add_time = timedelta(seconds=time) else: - await set_vocation.reject('仅支持数字+周/天/小时/分/分钟/秒, 请重新输入, 取消命令请发送【取消】:') + await set_vacation.reject('仅支持数字+周/天/小时/分/分钟/秒, 请重新输入, 取消命令请发送【取消】:') state['stop_at'] = datetime.now() + add_time -@set_vocation.got('stop_at', prompt='stop_at?') -@set_vocation.got('reason', prompt='请输入你的请假理由:') -async def handle_vocation_stop(bot: Bot, event: GroupMessageEvent, state: T_State): +@set_vacation.got('stop_at', prompt='stop_at?') +@set_vacation.got('reason', prompt='请输入你的请假理由:') +async def handle_vacation_stop(bot: Bot, event: GroupMessageEvent, state: T_State): user_id = event.user_id user = DBUser(user_id=user_id) stop_at = state['stop_at'] reason = state['reason'] - result = await user.vocation_set(stop_time=stop_at, reason=reason) + result = await user.vacation_set(stop_time=stop_at, reason=reason) if result.success(): - logger.info(f"Group: {event.group_id}/{user_id}, set_vocation, Success, {result.info}") - await set_vocation.finish(f'请假成功! 你的假期将持续到【{stop_at.strftime("%Y-%m-%d %H:%M:%S")}】') + logger.info(f"Group: {event.group_id}/{user_id}, set_vacation, Success, {result.info}") + await set_vacation.finish(f'请假成功! 你的假期将持续到【{stop_at.strftime("%Y-%m-%d %H:%M:%S")}】') else: - logger.error(f"Group: {event.group_id}/{user_id}, set_vocation, Failed, {result.info}") - await set_vocation.finish('请假失败, 发生了意外的错误QAQ') + logger.error(f"Group: {event.group_id}/{user_id}, set_vacation, Failed, {result.info}") + await set_vacation.finish('请假失败, 发生了意外的错误QAQ') # 注册事件响应器 -get_idle = vocation.on_command('谁有空') +get_idle = vacation.on_command('谁有空') # 修改默认参数处理 @@ -205,14 +205,15 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat elif args and len(args) == 1: state['skill'] = args[0] else: - await set_vocation.finish('参数错误QAQ') + await set_vacation.finish('参数错误QAQ') @get_idle.got('skill', prompt='空闲技能组?') async def handle_skill(bot: Bot, event: GroupMessageEvent, state: T_State): skill = state['skill'] group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) if not skill: result = await group.idle_member_list() if result.success() and result.result: @@ -228,10 +229,8 @@ async def handle_skill(bot: Bot, event: GroupMessageEvent, state: T_State): logger.error(f"Group: {event.group_id}, get_idle, Failed, {result.info}") await get_idle.finish(f'似乎发生了点错误QAQ') else: - skill_table = DBTable(table_name='Skill') - skill_res = await skill_table.list_col(col_name='name') - exist_skill = [x for x in skill_res.result] - if skill not in exist_skill: + skill_res = await DBSkill.list_available_skill() + if skill not in skill_res.result: await get_idle.reject(f'没有{skill}这个技能, 请重新输入, 取消命令请发送【取消】:') result = await group.idle_skill_list(skill=DBSkill(name=skill)) if result.success() and result.result: @@ -248,25 +247,26 @@ async def handle_skill(bot: Bot, event: GroupMessageEvent, state: T_State): # 注册事件响应器 -get_vocation = vocation.on_command('谁在休假') +get_vacation = vacation.on_command('谁在休假') -@get_vocation.handle() +@get_vacation.handle() async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): group_id = event.group_id - group = DBGroup(group_id=group_id) - result = await group.vocation_member_list() + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + result = await group.vacation_member_list() if result.success() and result.result: msg = '' for nickname, stop_at in result.result: msg += f'\n【{nickname}/休假到: {stop_at}】' - logger.info(f"Group: {event.group_id}, get_vocation, Success, {result.info}") - await get_vocation.finish(f'现在在休假的的人: \n{msg}') + logger.info(f"Group: {event.group_id}, get_vacation, Success, {result.info}") + await get_vacation.finish(f'现在在休假的的人: \n{msg}') elif result.success() and not result.result: - logger.info(f"Group: {event.group_id}, get_vocation, Success, {result.info}") - await get_vocation.finish(f'现在似乎没没有人休假呢~') + logger.info(f"Group: {event.group_id}, get_vacation, Success, {result.info}") + await get_vacation.finish(f'现在似乎没没有人休假呢~') else: - logger.error(f"Group: {event.group_id}, get_vocation, Failed, {result.info}") + logger.error(f"Group: {event.group_id}, get_vacation, Failed, {result.info}") await get_idle.finish(f'似乎发生了点错误QAQ') @@ -287,16 +287,16 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat # start_date=None, # end_date=None, # timezone=None, - id='member_vocations_monitor', + id='member_vacations_monitor', coalesce=True, misfire_grace_time=60 ) -async def member_vocations_monitor(): - logger.debug(f"member_vocations_monitor: vocation checking started") +async def member_vacations_monitor(): + logger.debug(f"member_vacations_monitor: vacation checking started") from nonebot import get_bots - over_vocation_user = set() + over_vacation_user = set() for bot_id, bot in get_bots().items(): group_list = await bot.call_api('get_group_list') @@ -304,11 +304,12 @@ async def member_vocations_monitor(): group_id = group.get('group_id') # 跳过不具备权限的组 - auth_check_res = await check_auth_node( - auth_id=group_id, auth_type='group', auth_node='Omega_vocation.basic') + self_bot = DBBot(self_qq=int(bot.self_id)) + auth_check_res = await PermissionChecker(self_bot=self_bot).check_auth_node( + auth_id=group_id, auth_type='group', auth_node='Omega_vacation.basic') if auth_check_res != 1: continue - logger.debug(f"member_vocations_monitor: checking group: {group_id}") + logger.debug(f"member_vacations_monitor: checking group: {group_id}") # 调用api获取群成员信息 group_member_list = await bot.call_api(api='get_group_member_list', group_id=group_id) @@ -319,14 +320,14 @@ async def member_vocations_monitor(): user_nickname = user_info['nickname'] user_qq = user_info['user_id'] user = DBUser(user_id=user_qq) - user_vocation_res = await user.vocation_status() - status, stop_time = user_vocation_res.result + user_vacation_res = await user.vacation_status() + status, stop_time = user_vacation_res.result if status == 1 and datetime.now() >= stop_time: msg = f'【{user_nickname}】的假期已经结束啦~\n快给他/她安排工作吧!' await bot.call_api(api='send_group_msg', group_id=group_id, message=msg) - over_vocation_user.add(user) - for user in over_vocation_user: + over_vacation_user.add(user) + for user in over_vacation_user: _res = await user.status_set(status=0) if not _res.success(): logger.error(f"reset user status failed: {_res.info}") - logger.debug('member_vocations_monitor: vocation checking completed') + logger.debug('member_vacations_monitor: vacation checking completed') diff --git a/omega_miya/plugins/announce/__init__.py b/omega_miya/plugins/announce/__init__.py index 2431d4bc..81305ab9 100644 --- a/omega_miya/plugins/announce/__init__.py +++ b/omega_miya/plugins/announce/__init__.py @@ -5,7 +5,7 @@ from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import PrivateMessageEvent -from omega_miya.utils.Omega_Base import DBTable +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup # 注册事件响应器 @@ -15,7 +15,7 @@ # 修改默认参数处理 @announce.args_parser async def parse(bot: Bot, event: PrivateMessageEvent, state: T_State): - args = str(event.get_plaintext()).strip().lower() + args = str(event.get_plaintext()).strip() if not args: await announce.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') state[state["_current_key"]] = args @@ -42,37 +42,37 @@ async def handle_first_receive(bot: Bot, event: PrivateMessageEvent, state: T_St async def handle_announce(bot: Bot, event: PrivateMessageEvent, state: T_State): group = state['group'] msg = state['announce_text'] + self_bot = DBBot(self_qq=int(bot.self_id)) if group == 'all': - t = DBTable(table_name='Group') - group_res = await t.list_col(col_name='group_id') + group_res = await DBBotGroup.list_exist_bot_groups(self_bot=self_bot) for group_id in group_res.result: try: - await bot.call_api(api='send_group_msg', group_id=group_id, message=msg) + await bot.send_group_msg(group_id=group_id, message=msg) except Exception as e: logger.warning(f'向群组发送公告失败, group: {group_id}, error: {repr(e)}') continue elif group == 'notice': - t = DBTable(table_name='Group') - group_res = await t.list_col_with_condition('group_id', 'notice_permissions', 1) + group_res = await DBBotGroup.list_exist_bot_groups_by_notice_permissions( + notice_permissions=1, self_bot=self_bot) for group_id in group_res.result: try: - await bot.call_api(api='send_group_msg', group_id=group_id, message=msg) + await bot.send_group_msg(group_id=group_id, message=msg) except Exception as e: logger.warning(f'向群组发送公告失败, group: {group_id}, error: {repr(e)}') continue elif group == 'command': - t = DBTable(table_name='Group') - group_res = await t.list_col_with_condition('group_id', 'command_permissions', 1) + group_res = await DBBotGroup.list_exist_bot_groups_by_command_permissions( + command_permissions=1, self_bot=self_bot) for group_id in group_res.result: try: - await bot.call_api(api='send_group_msg', group_id=group_id, message=msg) + await bot.send_group_msg(group_id=group_id, message=msg) except Exception as e: logger.warning(f'向群组发送公告失败, group: {group_id}, error: {repr(e)}') continue elif re.match(r'^\d+$', group): group_id = int(group) try: - await bot.call_api(api='send_group_msg', group_id=group_id, message=msg) + await bot.send_group_msg(group_id=group_id, message=msg) except Exception as e: logger.warning(f'向群组发送公告失败, group: {group_id}, error: {repr(e)}') else: diff --git a/omega_miya/plugins/bilibili_dynamic_monitor/__init__.py b/omega_miya/plugins/bilibili_dynamic_monitor/__init__.py index 57723804..f111822e 100644 --- a/omega_miya/plugins/bilibili_dynamic_monitor/__init__.py +++ b/omega_miya/plugins/bilibili_dynamic_monitor/__init__.py @@ -5,10 +5,10 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER, PRIVATE_FRIEND -from omega_miya.utils.Omega_Base import DBGroup, DBFriend, DBSubscription, Result +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBFriend, DBSubscription, Result from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state from omega_miya.utils.bilibili_utils import BiliUser -from .monitor import scheduler +from .monitor import init_user_dynamic, scheduler # Custom plugin usage text @@ -20,6 +20,10 @@ **Permission** Friend Private Command & Lv.20 +or AuthNode + +**AuthNode** +basic **Usage** **GroupAdmin and SuperUser Only** @@ -28,8 +32,13 @@ /B站动态 清空订阅 /B站动态 订阅列表''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) # 注册事件响应器 @@ -40,7 +49,8 @@ state=init_permission_state( name='bilibili_dynamic', command=True, - level=20), + level=20, + auth_node='basic'), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER | PRIVATE_FRIEND, priority=20, block=True) @@ -146,44 +156,53 @@ async def handle_check(bot: Bot, event: MessageEvent, state: T_State): await bilibili_dynamic.finish(f'{sub_command}失败了QAQ, 可能并未订阅该用户, 或请稍后再试~') -async def sub_list(bot: Bot, event: MessageEvent, state: T_State) -> Result.ListResult: +async def sub_list(bot: Bot, event: MessageEvent, state: T_State) -> Result.TupleListResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) result = await group.subscription_list_by_type(sub_type=2) return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) result = await friend.subscription_list_by_type(sub_type=2) return result else: - return Result.ListResult(error=True, info='Illegal event', result=[]) + return Result.TupleListResult(error=True, info='Illegal event', result=[]) async def sub_add(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) + uid = state['uid'] + sub = DBSubscription(sub_type=2, sub_id=uid) + need_init = not (await sub.exist()) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) - uid = state['uid'] - sub = DBSubscription(sub_type=2, sub_id=uid) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) _res = await sub.add(up_name=state.get('up_name'), live_info='B站动态') if not _res.success(): return _res - _res = await group.subscription_add(sub=sub) + # 初次订阅时立即刷新, 避免订阅后发送旧动态的问题 + if need_init: + await bot.send(event=event, message='初次订阅, 正在初始化动态信息, 请稍后...') + await init_user_dynamic(user_id=uid) + _res = await group.subscription_add(sub=sub, group_sub_info='B站动态') if not _res.success(): return _res result = Result.IntResult(error=False, info='Success', result=0) return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) - uid = state['uid'] - sub = DBSubscription(sub_type=2, sub_id=uid) + friend = DBFriend(user_id=user_id, self_bot=self_bot) _res = await sub.add(up_name=state.get('up_name'), live_info='B站动态') if not _res.success(): return _res - _res = await friend.subscription_add(sub=sub) + # 初次订阅时立即刷新, 避免订阅后发送旧动态的问题 + if need_init: + await bot.send(event=event, message='初次订阅, 正在初始化动态信息, 请稍后...') + await init_user_dynamic(user_id=uid) + _res = await friend.subscription_add(sub=sub, user_sub_info='B站动态') if not _res.success(): return _res result = Result.IntResult(error=False, info='Success', result=0) @@ -193,9 +212,10 @@ async def sub_add(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) uid = state['uid'] _res = await group.subscription_del(sub=DBSubscription(sub_type=2, sub_id=uid)) if not _res.success(): @@ -204,7 +224,7 @@ async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) uid = state['uid'] _res = await friend.subscription_del(sub=DBSubscription(sub_type=2, sub_id=uid)) if not _res.success(): @@ -216,9 +236,10 @@ async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe async def sub_clear(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) _res = await group.subscription_clear_by_type(sub_type=2) if not _res.success(): return _res @@ -226,7 +247,7 @@ async def sub_clear(bot: Bot, event: MessageEvent, state: T_State) -> Result.Int return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) _res = await friend.subscription_clear_by_type(sub_type=2) if not _res.success(): return _res diff --git a/omega_miya/plugins/bilibili_dynamic_monitor/monitor.py b/omega_miya/plugins/bilibili_dynamic_monitor/monitor.py index c1edf5f3..a9a5ab52 100644 --- a/omega_miya/plugins/bilibili_dynamic_monitor/monitor.py +++ b/omega_miya/plugins/bilibili_dynamic_monitor/monitor.py @@ -1,9 +1,12 @@ import asyncio import random +from typing import List from nonebot import logger, require, get_bots, get_driver from nonebot.adapters.cqhttp import MessageSegment -from omega_miya.utils.Omega_Base import DBFriend, DBSubscription, DBDynamic, DBTable +from nonebot.adapters.cqhttp.bot import Bot +from omega_miya.utils.Omega_Base import DBSubscription, DBDynamic from omega_miya.utils.bilibili_utils import BiliUser, BiliDynamic, BiliRequestUtils +from omega_miya.utils.Omega_plugin_utils import MsgSender from .config import Config @@ -39,8 +42,7 @@ ) async def dynamic_db_upgrade(): logger.debug('dynamic_db_upgrade: started upgrade subscription info') - t = DBTable(table_name='Subscription') - sub_res = await t.list_col_with_condition('sub_id', 'sub_type', 2) + sub_res = await DBSubscription.list_sub_by_type(sub_type=2) for sub_id in sub_res.result: sub = DBSubscription(sub_type=2, sub_id=sub_id) user_info_result = await BiliUser(user_id=sub_id).get_info() @@ -55,195 +57,181 @@ async def dynamic_db_upgrade(): logger.debug('dynamic_db_upgrade: upgrade subscription info completed') -# 创建动态检查函数 +# 处理图片序列 +async def pic_to_seg(pic_list: list) -> str: + # 处理图片序列 + pic_segs = [] + for pic_url in pic_list: + pic_result = await BiliRequestUtils.pic_to_file(url=pic_url) + if pic_result.error: + logger.warning(f'BiliDynamic get base64pic failed, error: {pic_result.info}, pic url: {pic_url}') + pic_segs.append(str(MessageSegment.image(pic_result.result))) + pic_seg = '\n'.join(pic_segs) + return pic_seg + + +# 检查单个用户动态的函数 +async def dynamic_checker(user_id: int, bots: List[Bot]): + # 获取动态并返回动态类型及内容 + user_dynamic_result = await BiliUser(user_id=user_id).get_dynamic_history() + if user_dynamic_result.error: + logger.error(f'bilibili_dynamic_monitor: 获取用户 {user_id} 动态失败, error: {user_dynamic_result.info}') + + # 解析动态内容 + dynamics_data = [] + for data in user_dynamic_result.result: + data_parse_result = BiliDynamic.data_parser(dynamic_data=data) + if data_parse_result.error: + logger.error(f'bilibili_dynamic_monitor: 解析新动态时发生了错误, error: {data_parse_result.info}') + continue + dynamics_data.append(data_parse_result) + + # 用户所有的动态id + exist_dynamic_result = await DBDynamic.list_dynamic_by_uid(uid=user_id) + if exist_dynamic_result.error: + logger.error(f'bilibili_dynamic_monitor: 获取用户 {user_id} 已有动态失败, error: {exist_dynamic_result.info}') + return + user_dynamic_list = [int(x) for x in exist_dynamic_result.result] + + new_dynamic_data = [data for data in dynamics_data if data.result.dynamic_id not in user_dynamic_list] + + subscription = DBSubscription(sub_type=2, sub_id=user_id) + + for data in new_dynamic_data: + dynamic_info = data.result + dynamic_card = dynamic_info.data + + dynamic_id = dynamic_info.dynamic_id + user_name = dynamic_info.user_name + desc = dynamic_info.desc + url = dynamic_info.url + + content = dynamic_card.content + title = dynamic_card.title + description = dynamic_card.description + + # 转发的动态 + if dynamic_info.type == 1: + # 转发的动态还需要获取原动态信息 + orig_dy_info_result = await BiliDynamic(dynamic_id=dynamic_info.orig_dy_id).get_info() + if orig_dy_info_result.success(): + orig_dy_data_result = BiliDynamic.data_parser(dynamic_data=orig_dy_info_result.result) + if orig_dy_data_result.success(): + # 原动态type=2, 8 或 4200, 带图片 + if orig_dy_data_result.result.type in [2, 8, 4200]: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=orig_dy_data_result.result.data.pictures) + orig_user = orig_dy_data_result.result.user_name + orig_contant = orig_dy_data_result.result.data.content + if not orig_contant: + orig_contant = orig_dy_data_result.result.data.title + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n" \ + f"@{orig_user}: {orig_contant}\n{pic_seg}" + # 原动态type=32 或 512, 为番剧类型 + elif orig_dy_data_result.result.type in [32, 512]: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=orig_dy_data_result.result.data.pictures) + orig_user = orig_dy_data_result.result.user_name + orig_title = orig_dy_data_result.result.data.title + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n" \ + f"@{orig_user}: {orig_title}\n{pic_seg}" + # 原动态为其他类型, 无图 + else: + orig_user = orig_dy_data_result.result.user_name + orig_contant = orig_dy_data_result.result.data.content + if not orig_contant: + orig_contant = orig_dy_data_result.result.data.title + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n" \ + f"@{orig_user}: {orig_contant}" + else: + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n@Unknown: 获取原动态失败" + else: + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n@Unknown: 获取原动态失败" + # 原创的动态(有图片) + elif dynamic_info.type == 2: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=dynamic_info.data.pictures) + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{pic_seg}" + # 原创的动态(无图片) + elif dynamic_info.type == 4: + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}" + # 视频 + elif dynamic_info.type == 8: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=dynamic_info.data.pictures) + if content: + msg = f"{user_name}{desc}!\n\n《{title}》\n\n“{content}”\n{url}\n{pic_seg}" + else: + msg = f"{user_name}{desc}!\n\n《{title}》\n\n{description}\n{url}\n{pic_seg}" + # 小视频 + elif dynamic_info.type == 16: + msg = f"{user_name}{desc}!\n\n“{content}”\n{url}" + # 番剧 + elif dynamic_info.type in [32, 512]: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=dynamic_info.data.pictures) + msg = f"{user_name}{desc}!\n\n《{title}》\n\n{content}\n{url}\n{pic_seg}" + # 文章 + elif dynamic_info.type == 64: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=dynamic_info.data.pictures) + msg = f"{user_name}{desc}!\n\n《{title}》\n\n{content}\n{url}\n{pic_seg}" + # 音频 + elif dynamic_info.type == 256: + # 处理图片序列 + pic_seg = await pic_to_seg(pic_list=dynamic_info.data.pictures) + msg = f"{user_name}{desc}!\n\n《{title}》\n\n{content}\n{url}\n{pic_seg}" + # B站活动相关 + elif dynamic_info.type == 2048: + if description: + msg = f"{user_name}{desc}!\n\n【{title} - {description}】\n\n“{content}”\n{url}" + else: + msg = f"{user_name}{desc}!\n\n【{title}】\n“{content}”\n\n{url}" + else: + logger.warning(f"未知的动态类型: {type}, id: {dynamic_id}") + continue + + # 向群组和好友推送消息 + for _bot in bots: + msg_sender = MsgSender(bot=_bot, log_flag='BiliDynamicNotice') + await msg_sender.safe_broadcast_groups_subscription(subscription=subscription, message=msg) + await msg_sender.safe_broadcast_friends_subscription(subscription=subscription, message=msg) + + # 更新动态内容到数据库 + # 向数据库中写入动态信息 + dynamic = DBDynamic(uid=user_id, dynamic_id=dynamic_id) + _res = await dynamic.add(dynamic_type=dynamic_info.type, content=content) + if _res.success(): + logger.info(f"向数据库写入动态信息: {dynamic_id} 成功") + else: + logger.error(f"向数据库写入动态信息: {dynamic_id} 失败, error: {_res.info}") + + +# 用于首次订阅时刷新数据库动态信息 +async def init_user_dynamic(user_id: int): + # 暂停计划任务避免中途检查更新 + scheduler.pause() + await dynamic_checker(user_id=user_id, bots=[]) + scheduler.resume() + logger.info(f'Init new subscription user {user_id} dynamic completed.') + + +# 动态检查主函数 async def bilibili_dynamic_monitor(): logger.debug(f"bilibili_dynamic_monitor: checking started") # 获取当前bot列表 - bots = [] - for bot_id, bot in get_bots().items(): - bots.append(bot) - - # 获取所有有通知权限的群组 - t = DBTable(table_name='Group') - group_res = await t.list_col_with_condition('group_id', 'notice_permissions', 1) - all_noitce_groups = [int(x) for x in group_res.result] - - # 获取所有启用了私聊功能的好友 - friend_res = await DBFriend.list_exist_friends_by_private_permission(private_permission=1) - all_noitce_friends = [int(x) for x in friend_res.result] + bots = [bot for bot_id, bot in get_bots().items()] # 获取订阅表中的所有动态订阅 - t = DBTable(table_name='Subscription') - sub_res = await t.list_col_with_condition('sub_id', 'sub_type', 2) + sub_res = await DBSubscription.list_sub_by_type(sub_type=2) check_sub = [int(x) for x in sub_res.result] if not check_sub: logger.debug(f'bilibili_dynamic_monitor: no dynamic subscription, ignore.') return - # 处理图片序列 - async def pic2base64(pic_list: list) -> str: - # 处理图片序列 - pic_segs = [] - for pic_url in pic_list: - pic_result = await BiliRequestUtils.pic_2_base64(url=pic_url) - pic_b64 = pic_result.result - pic_segs.append(str(MessageSegment.image(pic_b64))) - pic_seg = '\n'.join(pic_segs) - return pic_seg - - # 注册一个异步函数用于检查动态 - async def check_dynamic(user_id: int): - # 获取动态并返回动态类型及内容 - user_dynamic_result = await BiliUser(user_id=user_id).get_dynamic_history() - if user_dynamic_result.error: - logger.error(f'bilibili_dynamic_monitor: 获取用户 {user_id} 动态失败, error: {user_dynamic_result.info}') - - # 解析动态内容 - dynamics_data = [] - for data in user_dynamic_result.result: - data_parse_result = BiliDynamic.data_parser(dynamic_data=data) - if data_parse_result.error: - logger.error(f'bilibili_dynamic_monitor: 解析新动态时发生了错误, error: {data_parse_result.info}') - continue - dynamics_data.append(data_parse_result) - - # 用户所有的动态id - dynamic_table = DBTable(table_name='Bilidynamic') - exist_dynamic_result = await dynamic_table.list_col_with_condition('dynamic_id', 'uid', user_id) - if exist_dynamic_result.error: - logger.error(f'bilibili_dynamic_monitor: 获取用户 {user_id} 已有动态失败, error: {exist_dynamic_result.info}') - return - user_dynamic_list = [int(x) for x in exist_dynamic_result.result] - - new_dynamic_data = [data for data in dynamics_data if data.result.dynamic_id not in user_dynamic_list] - - sub = DBSubscription(sub_type=2, sub_id=user_id) - - # 获取订阅了该直播间的所有群 - sub_group_res = await sub.sub_group_list() - sub_group = sub_group_res.result - # 需通知的群 - notice_groups = list(set(all_noitce_groups) & set(sub_group)) - - # 获取订阅了该直播间的所有好友 - sub_friend_res = await sub.sub_user_list() - sub_friend = sub_friend_res.result - # 需通知的好友 - notice_friends = list(set(all_noitce_friends) & set(sub_friend)) - - for data in new_dynamic_data: - dynamic_info = data.result - dynamic_card = dynamic_info.data - - dynamic_id = dynamic_info.dynamic_id - user_name = dynamic_info.user_name - desc = dynamic_info.desc - url = dynamic_info.url - - content = dynamic_card.content - title = dynamic_card.title - description = dynamic_card.description - - # 转发的动态 - if dynamic_info.type == 1: - # 转发的动态还需要获取原动态信息 - orig_dy_info_result = await BiliDynamic(dynamic_id=dynamic_info.orig_dy_id).get_info() - if orig_dy_info_result.success(): - orig_dy_data_result = BiliDynamic.data_parser(dynamic_data=orig_dy_info_result.result) - if orig_dy_data_result.success(): - # 原动态type=2 或 8, 带图片 - if orig_dy_data_result.result.type in [2, 8]: - # 处理图片序列 - pic_seg = await pic2base64(pic_list=orig_dy_data_result.result.data.pictures) - orig_user = orig_dy_data_result.result.user_name - orig_contant = orig_dy_data_result.result.data.content - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n" \ - f"@{orig_user}: {orig_contant}\n{pic_seg}" - # 原动态为其他类型, 无图 - else: - orig_user = orig_dy_data_result.result.user_name - orig_contant = orig_dy_data_result.result.data.content - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n" \ - f"@{orig_user}: {orig_contant}" - else: - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n@Unknown: 获取原动态失败" - else: - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{'=' * 16}\n@Unknown: 获取原动态失败" - # 原创的动态(有图片) - elif dynamic_info.type == 2: - # 处理图片序列 - pic_seg = await pic2base64(pic_list=dynamic_info.data.pictures) - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}\n{pic_seg}" - # 原创的动态(无图片) - elif dynamic_info.type == 4: - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}" - # 视频 - elif dynamic_info.type == 8: - # 处理图片序列 - pic_seg = await pic2base64(pic_list=dynamic_info.data.pictures) - if content: - msg = f"{user_name}{desc}!\n\n《{title}》\n\n“{content}”\n{url}\n{pic_seg}" - else: - msg = f"{user_name}{desc}!\n\n《{title}》\n\n{description}\n{url}\n{pic_seg}" - # 小视频 - elif dynamic_info.type == 16: - msg = f"{user_name}{desc}!\n\n“{content}”\n{url}" - # 番剧 - elif dynamic_info.type in [32, 512]: - # 处理图片序列 - pic_seg = await pic2base64(pic_list=dynamic_info.data.pictures) - msg = f"{user_name}{desc}!\n\n《{title}》\n\n{content}\n{url}\n{pic_seg}" - # 文章 - elif dynamic_info.type == 64: - # 处理图片序列 - pic_seg = await pic2base64(pic_list=dynamic_info.data.pictures) - msg = f"{user_name}{desc}!\n\n《{title}》\n\n{content}\n{url}\n{pic_seg}" - # 音频 - elif dynamic_info.type == 256: - # 处理图片序列 - pic_seg = await pic2base64(pic_list=dynamic_info.data.pictures) - msg = f"{user_name}{desc}!\n\n《{title}》\n\n{content}\n{url}\n{pic_seg}" - # B站活动相关 - elif dynamic_info.type == 2048: - if description: - msg = f"{user_name}{desc}!\n\n【{title} - {description}】\n\n“{content}”\n{url}" - else: - msg = f"{user_name}{desc}!\n\n【{title}】\n“{content}”\n\n{url}" - else: - logger.warning(f"未知的动态类型: {type}, id: {dynamic_id}") - continue - - # 向群组发送消息 - for group_id in notice_groups: - for _bot in bots: - try: - await _bot.call_api(api='send_group_msg', group_id=group_id, message=msg) - logger.info(f"向群组: {group_id} 发送新动态通知: {dynamic_id}") - except Exception as _e: - logger.warning(f"向群组: {group_id} 发送新动态通知: {dynamic_id} 失败, error: {repr(_e)}") - continue - # 向好友发送消息 - for friend_user_id in notice_friends: - for _bot in bots: - try: - await _bot.call_api(api='send_private_msg', user_id=friend_user_id, message=msg) - logger.info(f"向好友: {friend_user_id} 发送新动态通知: {dynamic_id}") - except Exception as _e: - logger.warning(f"向好友: {friend_user_id} 发送新动态通知: {dynamic_id} 失败, error: {repr(_e)}") - continue - - # 更新动态内容到数据库 - # 向数据库中写入动态信息 - dynamic = DBDynamic(uid=user_id, dynamic_id=dynamic_id) - _res = await dynamic.add(dynamic_type=dynamic_info.type, content=content) - if _res.success(): - logger.info(f"向数据库写入动态信息: {dynamic_id} 成功") - else: - logger.error(f"向数据库写入动态信息: {dynamic_id} 失败, error: {_res.info}") - # 启用了检查池模式 if ENABLE_DYNAMIC_CHECK_POOL_MODE: global checking_pool @@ -255,11 +243,11 @@ async def check_dynamic(user_id: int): # 看下checking_pool里面还剩多少 waiting_num = len(checking_pool) - # 默认单次检查并发数为2, 默认检查间隔为20s + # 默认单次检查并发数为3, 默认检查间隔为20s logger.debug(f'bili dynamic pool mode debug info, B_checking_pool: {checking_pool}') - if waiting_num >= 2: + if waiting_num >= 3: # 抽取检查对象 - now_checking = random.sample(checking_pool, k=2) + now_checking = random.sample(checking_pool, k=3) # 更新checking_pool checking_pool = [x for x in checking_pool if x not in now_checking] else: @@ -271,7 +259,7 @@ async def check_dynamic(user_id: int): # 检查now_checking里面的直播间(异步) tasks = [] for uid in now_checking: - tasks.append(check_dynamic(uid)) + tasks.append(dynamic_checker(user_id=uid, bots=bots)) try: await asyncio.gather(*tasks) logger.debug(f"bilibili_dynamic_monitor: pool mode enable, checking completed, " @@ -281,19 +269,18 @@ async def check_dynamic(user_id: int): # 没有启用检查池模式 else: - # 检查所有在订阅表里面的直播间(异步) + # 检查所有在订阅表里面的动态(异步) tasks = [] for uid in check_sub: - tasks.append(check_dynamic(uid)) + tasks.append(dynamic_checker(user_id=uid, bots=bots)) try: await asyncio.gather(*tasks) logger.debug(f"bilibili_dynamic_monitor: pool mode disable, checking completed, " f"checked: {', '.join([str(x) for x in check_sub])}.") except Exception as e: - logger.error(f'bilibili_dynamic_monitor: pool mode disable, error occurred in checking {repr(e)}') + logger.error(f'bilibili_dynamic_monitor: pool mode disable, error occurred in checking {repr(e)}') -# 分时间段创建计划任务, 夜间闲时降低检查频率 # 根据检查池模式初始化检查时间间隔 if ENABLE_DYNAMIC_CHECK_POOL_MODE: # 检查池启用 @@ -313,7 +300,7 @@ async def check_dynamic(user_id: int): # timezone=None, id='bilibili_dynamic_monitor_pool_enable', coalesce=True, - misfire_grace_time=30 + misfire_grace_time=20 ) else: # 检查池禁用 @@ -337,5 +324,6 @@ async def check_dynamic(user_id: int): ) __all__ = [ - 'scheduler' + 'scheduler', + 'init_user_dynamic' ] diff --git a/omega_miya/plugins/bilibili_live_monitor/__init__.py b/omega_miya/plugins/bilibili_live_monitor/__init__.py index 2bda6dcc..ac0ae1fe 100644 --- a/omega_miya/plugins/bilibili_live_monitor/__init__.py +++ b/omega_miya/plugins/bilibili_live_monitor/__init__.py @@ -5,7 +5,7 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER, PRIVATE_FRIEND -from omega_miya.utils.Omega_Base import DBGroup, DBFriend, DBSubscription, Result +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBFriend, DBSubscription, Result from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state from omega_miya.utils.bilibili_utils import BiliLiveRoom from .data_source import BiliLiveChecker @@ -22,6 +22,10 @@ **Permission** Friend Private Command & Lv.20 +or AuthNode + +**AuthNode** +basic **Usage** **GroupAdmin and SuperUser Only** @@ -30,8 +34,13 @@ /B站直播间 清空订阅 /B站直播间 订阅列表''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) # 注册事件响应器 @@ -42,7 +51,8 @@ state=init_permission_state( name='bilibili_live', command=True, - level=20), + level=20, + auth_node='basic'), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER | PRIVATE_FRIEND, priority=20, block=True) @@ -148,31 +158,33 @@ async def handle_check(bot: Bot, event: MessageEvent, state: T_State): await bilibili_live.finish(f'{sub_command}失败了QAQ, 可能并未订阅该用户, 或请稍后再试~') -async def sub_list(bot: Bot, event: MessageEvent, state: T_State) -> Result.ListResult: +async def sub_list(bot: Bot, event: MessageEvent, state: T_State) -> Result.TupleListResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) result = await group.subscription_list_by_type(sub_type=1) return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) result = await friend.subscription_list_by_type(sub_type=1) return result else: - return Result.ListResult(error=True, info='Illegal event', result=[]) + return Result.TupleListResult(error=True, info='Illegal event', result=[]) async def sub_add(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) room_id = state['room_id'] sub = DBSubscription(sub_type=1, sub_id=room_id) _res = await sub.add(up_name=state.get('up_name'), live_info='B站直播间') if not _res.success(): return _res - _res = await group.subscription_add(sub=sub) + _res = await group.subscription_add(sub=sub, group_sub_info='B站直播间') if not _res.success(): return _res # 添加直播间时需要刷新全局监控列表 @@ -182,13 +194,13 @@ async def sub_add(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) room_id = state['room_id'] sub = DBSubscription(sub_type=1, sub_id=room_id) _res = await sub.add(up_name=state.get('up_name'), live_info='B站直播间') if not _res.success(): return _res - _res = await friend.subscription_add(sub=sub) + _res = await friend.subscription_add(sub=sub, user_sub_info='B站直播间') if not _res.success(): return _res # 添加直播间时需要刷新全局监控列表 @@ -201,9 +213,10 @@ async def sub_add(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) room_id = state['room_id'] _res = await group.subscription_del(sub=DBSubscription(sub_type=1, sub_id=room_id)) if not _res.success(): @@ -212,7 +225,7 @@ async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) room_id = state['room_id'] _res = await friend.subscription_del(sub=DBSubscription(sub_type=1, sub_id=room_id)) if not _res.success(): @@ -224,9 +237,10 @@ async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntRe async def sub_clear(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) if isinstance(event, GroupMessageEvent): group_id = event.group_id - group = DBGroup(group_id=group_id) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) _res = await group.subscription_clear_by_type(sub_type=1) if not _res.success(): return _res @@ -234,7 +248,7 @@ async def sub_clear(bot: Bot, event: MessageEvent, state: T_State) -> Result.Int return result elif isinstance(event, PrivateMessageEvent): user_id = event.user_id - friend = DBFriend(user_id=user_id) + friend = DBFriend(user_id=user_id, self_bot=self_bot) _res = await friend.subscription_clear_by_type(sub_type=1) if not _res.success(): return _res diff --git a/omega_miya/plugins/bilibili_live_monitor/data_source.py b/omega_miya/plugins/bilibili_live_monitor/data_source.py index f9ad3bd7..ec69c2b8 100644 --- a/omega_miya/plugins/bilibili_live_monitor/data_source.py +++ b/omega_miya/plugins/bilibili_live_monitor/data_source.py @@ -3,10 +3,11 @@ from dataclasses import dataclass from typing import List, Union from nonebot import logger -from nonebot.adapters import Bot +from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp import MessageSegment -from omega_miya.utils.Omega_Base import DBSubscription, DBHistory, DBTable, Result +from omega_miya.utils.Omega_Base import DBSubscription, DBHistory, Result from omega_miya.utils.bilibili_utils import BiliLiveRoom, BiliUser, BiliRequestUtils, BiliInfo +from omega_miya.utils.Omega_plugin_utils import MsgSender # 初始化直播间标题, 状态 @@ -72,9 +73,9 @@ async def init_global_live_info(cls): logger.opt(colors=True).warning(f'Bilibili 登录状态异常: {cookies_result.info}! 建议在配置中正确设置cookies!') logger.opt(colors=True).info('init_live_info: 初始化B站直播间监控列表...') - t = DBTable(table_name='Subscription') + tasks = [] - sub_res = await t.list_col_with_condition('sub_id', 'sub_type', 1) + sub_res = await DBSubscription.list_sub_by_type(sub_type=1) for sub_id in sub_res.result: tasks.append(BiliLiveChecker(room_id=sub_id).init_live_info()) try: @@ -155,12 +156,14 @@ async def title_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveRo # 直播过程中标题更新 elif live_info.status == 1 and live_info.title != live_title[self.room_id]: if live_info.cover_img: - cover_pic_result = await BiliRequestUtils.pic_2_base64(url=live_info.cover_img) + cover_pic_result = await BiliRequestUtils.pic_to_file(url=live_info.cover_img) if cover_pic_result.success(): # 发送的消息 msg = f"{up_name}的直播间换标题啦!\n\n【{live_info.title}】\n" \ f"{MessageSegment.image(cover_pic_result.result)}" else: + logger.warning(f'BiliLive get base64pic failed, ' + f'error: {cover_pic_result.info}, pic url: {live_info.cover_img}') msg = f"{up_name}的直播间换标题啦!\n\n【{live_info.title}】" else: msg = f"{up_name}的直播间换标题啦!\n\n【{live_info.title}】" @@ -188,6 +191,11 @@ async def status_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveR if live_info.status != live_status[self.room_id]: # 现在状态为未开播 if live_info.status == 0: + # 只有当从直播中状态切换到下播状态时才通知, 避免下播与轮播之间切换时发送通知 + if live_status[self.room_id] == 1: + action = True + else: + action = False # 事件记录写入数据库 live_end_info = f"LiveEnd! Room: {self.room_id}/{up_name}" new_event = DBHistory(time=int(time.time()), self_id=-1, post_type='bilibili', @@ -197,12 +205,12 @@ async def status_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveR # 更新直播间状态 live_status[self.room_id] = live_info.status - logger.info(f"直播间: {self.room_id}/{up_name} 下播了") + logger.info(f"直播间: {self.room_id}/{up_name} 状态切换为下播.") # 发送的消息 msg = f'{up_name}下播了' - return self.LiveRoomCheckerResult(error=False, changed=True, action=True, info='Success', + return self.LiveRoomCheckerResult(error=False, changed=True, action=action, info='Success', original=old_status, new=live_info.status, result=msg) # 现在状态为直播中 @@ -217,15 +225,17 @@ async def status_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveR # 更新直播间状态 live_status[self.room_id] = live_info.status - logger.info(f"直播间: {self.room_id}/{up_name} 开播了") + logger.info(f"直播间: {self.room_id}/{up_name} 状态切换为开播.") # 发送的消息 if live_info.cover_img: - cover_pic_result = await BiliRequestUtils.pic_2_base64(url=live_info.cover_img) + cover_pic_result = await BiliRequestUtils.pic_to_file(url=live_info.cover_img) if cover_pic_result.success(): msg = f"{live_info.live_time}\n{up_name}开播啦!\n\n【{live_info.title}】" \ f"\n{MessageSegment.image(cover_pic_result.result)}" else: + logger.warning(f'BiliLive get base64pic failed, ' + f'error: {cover_pic_result.info}, pic url: {live_info.cover_img}') msg = f"{live_info.live_time}\n{up_name}开播啦!\n\n【{live_info.title}】" else: msg = f"{live_info.live_time}\n{up_name}开播啦!\n\n【{live_info.title}】" @@ -235,6 +245,11 @@ async def status_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveR # 现在状态为未开播(轮播中) elif live_info.status == 2: + # 只有当从直播中状态切换到下播状态时才通知, 避免下播与轮播之间切换时发送通知 + if live_status[self.room_id] == 1: + action = True + else: + action = False # 事件记录写入数据库 live_end_info = f"LiveEnd! Room: {self.room_id}/{up_name}" new_event = DBHistory(time=int(time.time()), self_id=-1, post_type='bilibili', @@ -244,12 +259,12 @@ async def status_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveR # 更新直播间状态 live_status[self.room_id] = live_info.status - logger.info(f"直播间: {self.room_id}/{up_name} 下播了(轮播中)") + logger.info(f"直播间: {self.room_id}/{up_name} 状态切换为下播(轮播中).") # 发送的消息 msg = f'{up_name}下播了(轮播中)' - return self.LiveRoomCheckerResult(error=False, changed=True, action=True, info='Success', + return self.LiveRoomCheckerResult(error=False, changed=True, action=action, info='Success', original=old_status, new=live_info.status, result=msg) # 遇到的奇怪的状态 @@ -271,84 +286,37 @@ async def status_change_checker(self, live_info: BiliInfo.LiveRoomInfo) -> LiveR return self.LiveRoomCheckerResult(error=False, changed=False, action=False, info='Success', original=old_status, new=live_info.status, result='') - async def broadcaster( - self, live_info: BiliInfo.LiveRoomInfo, bots: List[Bot], all_groups: List[int], all_friends: List[int]): + async def broadcaster(self, live_info: BiliInfo.LiveRoomInfo, bots: List[Bot]): """ 检查直播间状态并向群组发送消息 :param live_info: 由 get_live_info 或 get_live_info_by_uid_list 获取的直播间信息 :param bots: bots 列表 - :param all_groups: 所有可能需要通知的群组列表 - :param all_friends: 所有可能需要通知的好友列表 """ global_check_result = await self.check_global_status() if global_check_result.error: return - sub = DBSubscription(sub_type=1, sub_id=self.room_id) - - # 获取订阅了该直播间的所有群 - sub_group_res = await sub.sub_group_list() - sub_group = sub_group_res.result - # 需通知的群 - notice_group = list(set(all_groups) & set(sub_group)) - - # 获取订阅了该直播间的所有好友 - sub_friend_res = await sub.sub_user_list() - sub_friend = sub_friend_res.result - # 需通知的好友 - notice_friends = list(set(all_friends) & set(sub_friend)) + subscription = DBSubscription(sub_type=1, sub_id=self.room_id) # 标题变更检测 title_checker_result = await self.title_change_checker(live_info=live_info) if title_checker_result.success() and title_checker_result.action: # 通知有通知权限且订阅了该直播间的群 - for group_id in notice_group: - for _bot in bots: - try: - await _bot.call_api( - api='send_group_msg', group_id=group_id, message=title_checker_result.result) - logger.info(f"向群组: {group_id} 发送直播间: {self.room_id} 标题变更通知") - except Exception as e: - logger.warning(f"向群组: {group_id} 发送直播间: {self.room_id} 标题变更通知失败, error: {repr(e)}") - continue - # 通知有通知权限且订阅了该直播间的好友 - for user_id in notice_friends: - for _bot in bots: - try: - await _bot.call_api( - api='send_private_msg', user_id=user_id, message=title_checker_result.result) - logger.info(f"向好友: {user_id} 发送直播间: {self.room_id} 标题变更通知") - except Exception as e: - logger.warning(f"向好友: {user_id} 发送直播间: {self.room_id} 标题变更通知失败, error: {repr(e)}") - continue + msg = title_checker_result.result + for _bot in bots: + msg_sender = MsgSender(bot=_bot, log_flag='BiliLiveTitleNotice') + await msg_sender.safe_broadcast_groups_subscription(subscription=subscription, message=msg) + await msg_sender.safe_broadcast_friends_subscription(subscription=subscription, message=msg) # 状态变更检测 status_checker_result = await self.status_change_checker(live_info=live_info) if status_checker_result.success() and status_checker_result.action: # 通知有通知权限且订阅了该直播间的群 - up_name = live_up_name[self.room_id] - status = live_status[self.room_id] - for group_id in notice_group: - for _bot in bots: - try: - await _bot.call_api( - api='send_group_msg', group_id=group_id, message=status_checker_result.result) - logger.info( - f"向群组: {group_id} 发送直播间: {self.room_id}/{up_name} 直播通知, status: {status}") - except Exception as e: - logger.warning(f"向群组: {group_id} 发送直播间: {self.room_id}/{up_name} 直播通知失败, error: {repr(e)}") - continue - # 通知有通知权限且订阅了该直播间的好友 - for user_id in notice_friends: - for _bot in bots: - try: - await _bot.call_api( - api='send_private_msg', user_id=user_id, message=status_checker_result.result) - logger.info( - f"向好友: {user_id} 发送直播间: {self.room_id}/{up_name} 直播通知, status: {status}") - except Exception as e: - logger.warning(f"向好友: {user_id} 发送直播间: {self.room_id}/{up_name} 直播通知失败, error: {repr(e)}") - continue + msg = status_checker_result.result + for _bot in bots: + msg_sender = MsgSender(bot=_bot, log_flag='BiliLiveStatusNotice') + await msg_sender.safe_broadcast_groups_subscription(subscription=subscription, message=msg) + await msg_sender.safe_broadcast_friends_subscription(subscription=subscription, message=msg) __all__ = [ diff --git a/omega_miya/plugins/bilibili_live_monitor/monitor.py b/omega_miya/plugins/bilibili_live_monitor/monitor.py index 0482e8dd..3aa33632 100644 --- a/omega_miya/plugins/bilibili_live_monitor/monitor.py +++ b/omega_miya/plugins/bilibili_live_monitor/monitor.py @@ -1,7 +1,7 @@ import asyncio import random from nonebot import logger, require, get_driver, get_bots -from omega_miya.utils.Omega_Base import DBFriend, DBSubscription, DBTable +from omega_miya.utils.Omega_Base import DBSubscription from omega_miya.utils.bilibili_utils import BiliLiveRoom from .data_source import BiliLiveChecker from .config import Config @@ -42,8 +42,7 @@ ) async def live_db_upgrade(): logger.debug('live_db_upgrade: started upgrade subscription info') - t = DBTable(table_name='Subscription') - sub_res = await t.list_col_with_condition('sub_id', 'sub_type', 1) + sub_res = await DBSubscription.list_sub_by_type(sub_type=1) for sub_id in sub_res.result: sub = DBSubscription(sub_type=1, sub_id=sub_id) live_user_info_result = await BiliLiveRoom(room_id=sub_id).get_user_info() @@ -63,22 +62,10 @@ async def bilibili_live_monitor(): logger.debug(f"bilibili_live_monitor: checking started") # 获取当前bot列表 - bots = [] - for bot_id, bot in get_bots().items(): - bots.append(bot) - - # 获取所有有通知权限的群组 - t = DBTable(table_name='Group') - group_res = await t.list_col_with_condition('group_id', 'notice_permissions', 1) - all_noitce_groups = [int(x) for x in group_res.result] - - # 获取所有启用了私聊功能的好友 - friend_res = await DBFriend.list_exist_friends_by_private_permission(private_permission=1) - all_noitce_friends = [int(x) for x in friend_res.result] + bots = [bot for bot_id, bot in get_bots().items()] # 获取订阅表中的所有直播间订阅 - t = DBTable(table_name='Subscription') - sub_res = await t.list_col_with_condition('sub_id', 'sub_type', 1) + sub_res = await DBSubscription.list_sub_by_type(sub_type=1) check_sub = [int(x) for x in sub_res.result] if not check_sub: @@ -94,8 +81,7 @@ async def check_live(room_id: int): return live_info = live_info_result.result try: - await BiliLiveChecker(room_id=room_id).broadcaster( - live_info=live_info, bots=bots, all_groups=all_noitce_groups, all_friends=all_noitce_friends) + await BiliLiveChecker(room_id=room_id).broadcaster(live_info=live_info, bots=bots) except Exception as _e: logger.error(f'bilibili_live_monitor: 处理直播间 {room_id} 状态信息是发生错误: {repr(_e)}') @@ -128,8 +114,7 @@ async def check_live_by_rids(room_id_list: list): # 依次处理各直播间信息 for room_id, live_info in live_info_.items(): try: - await BiliLiveChecker(room_id=room_id).broadcaster( - live_info=live_info, bots=bots, all_groups=all_noitce_groups, all_friends=all_noitce_friends) + await BiliLiveChecker(room_id=room_id).broadcaster(live_info=live_info, bots=bots) except Exception as _e: logger.error(f'bilibili_live_monitor: 处理直播间 {room_id} 状态信息是发生错误: {repr(_e)}') continue @@ -154,11 +139,11 @@ async def check_live_by_rids(room_id_list: list): # 看下checking_pool里面还剩多少 waiting_num = len(checking_pool) - # 默认单次检查并发数为2, 默认检查间隔为20s + # 默认单次检查并发数为3, 默认检查间隔为20s logger.debug(f'bili live pool mode debug info, B_checking_pool: {checking_pool}') - if waiting_num >= 2: + if waiting_num >= 3: # 抽取检查对象 - now_checking = random.sample(checking_pool, k=2) + now_checking = random.sample(checking_pool, k=3) # 更新checking_pool checking_pool = [x for x in checking_pool if x not in now_checking] else: @@ -191,7 +176,6 @@ async def check_live_by_rids(room_id_list: list): logger.error(f'bilibili_live_monitor: pool mode disable, error occurred in checking {repr(e)}') -# 分时间段创建计划任务, 夜间闲时降低检查频率 # 根据检查池模式初始化检查时间间隔 if ENABLE_NEW_LIVE_API: # 使用新api diff --git a/omega_miya/plugins/calculator/__init__.py b/omega_miya/plugins/calculator/__init__.py new file mode 100644 index 00000000..10ff53b3 --- /dev/null +++ b/omega_miya/plugins/calculator/__init__.py @@ -0,0 +1,99 @@ +""" +@Author : Ailitonia +@Date : 2021/07/18 15:39 +@FileName : calculator.py +@Project : nonebot2_miya +@Description : 简易计算器 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from nonebot import on_command, export, logger +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP, PRIVATE_FRIEND +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state +from omega_miya.utils.dice_utils import BaseCalculator +from omega_miya.utils.dice_utils.exception import CalculateException + + +# Custom plugin usage text +__plugin_name__ = '计算器' +__plugin_usage__ = r'''【简易计算器】 +只能计算加减乘除和乘方! + +**Permission** +Command & Lv.10 +or AuthNode + +**AuthNode** +basic + +**Usage** +/计算器 [算式]''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +# 注册事件响应器 +calculator = on_command( + 'Calculator', + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='calculator', + command=True, + level=10, + auth_node='basic'), + aliases={'calculator', '计算器', '计算'}, + permission=GROUP | PRIVATE_FRIEND, + priority=20, + block=True) + + +# 修改默认参数处理 +@calculator.args_parser +async def parse(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await calculator.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await calculator.finish('操作已取消') + + +@calculator.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['expression'] = args[0] + else: + await calculator.finish('参数错误QAQ') + + +@calculator.got('expression', prompt='请输入你想要计算的算式(只支持加减乘除和乘方):') +async def handle_calculator(bot: Bot, event: MessageEvent, state: T_State): + expression = state['expression'] + if len(expression) >= 128: + logger.warning(f'Calculator | 超过长度限制的算式: {expression}') + await calculator.finish('算式太长了QAQ') + + try: + result = await BaseCalculator(expression=expression).aio_std_calculate() + except CalculateException as e: + logger.warning(f'Calculator | 计算失败, error: {repr(e)}') + await calculator.finish(f'计算失败QAQ, {e.reason}') + return + except Exception as e: + logger.error(f'Calculator | 计算失败, error: {repr(e)}') + await calculator.finish(f'计算失败QAQ,也许算式超出计算范围了') + return + await calculator.finish(f'{expression}的计算结果是:\n\n{result}') diff --git a/omega_miya/plugins/draw/__init__.py b/omega_miya/plugins/draw/__init__.py index 8e2b36ea..e28d4e13 100644 --- a/omega_miya/plugins/draw/__init__.py +++ b/omega_miya/plugins/draw/__init__.py @@ -1,3 +1,5 @@ +import re +import random from nonebot import CommandGroup, export, logger from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot @@ -13,20 +15,27 @@ 没有保底的啦! 不要上头啊喂! 仅限群聊使用 +ps: 附带一个群组抽奖功能 **Permission** Command & Lv.10 +or AuthNode + +**AuthNode** +basic **CoolDown** 用户冷却时间 1 Minutes **Usage** -/抽卡 [卡组]''' +/抽卡 [卡组] +/抽奖 [人数]''' # 声明本插件可配置的权限节点 __plugin_auth_node__ = [ - PluginCoolDown.skip_auth_node + PluginCoolDown.skip_auth_node, + 'basic' ] # 声明本插件的冷却时间配置 @@ -39,17 +48,18 @@ # 注册事件响应器 Draw = CommandGroup( - 'draw', + 'Draw', # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 state=init_permission_state( name='draw', command=True, - level=10), + level=10, + auth_node='basic'), permission=GROUP, priority=10, block=True) -deck = Draw.command('deck', aliases={'抽卡'}) +deck = Draw.command('draw', aliases={'抽卡'}) # 修改默认参数处理 @@ -94,8 +104,62 @@ async def handle_deck(bot: Bot, event: GroupMessageEvent, state: T_State): if not draw_user: draw_user = event.sender.nickname - draw_result = draw_deck(_draw)(user_id=user_id) + draw_result = draw_deck(_draw)(user_id) # 向用户发送结果 msg = f"{draw_user}抽卡【{_draw}】!!\n{'='*12}\n{draw_result}" await deck.finish(msg) + + +lottery = Draw.command('lottery', aliases={'抽奖'}) + + +# 修改默认参数处理 +@lottery.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await lottery.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await lottery.finish('操作已取消') + + +@lottery.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['lottery'] = args[0] + else: + await lottery.finish('参数错误QAQ') + + +@lottery.got('lottery', prompt='请输入抽奖人数') +async def handle_lottery(bot: Bot, event: GroupMessageEvent, state: T_State): + _lottery = state['lottery'] + if re.match(r'^\d+$', _lottery): + people_num = int(_lottery) + + group_member_list = await bot.get_group_member_list(group_id=event.group_id) + group_user_name_list = [] + + for user_info in group_member_list: + # 用户信息 + user_nickname = user_info['nickname'] + user_group_nickmane = user_info['card'] + if not user_group_nickmane: + user_group_nickmane = user_nickname + group_user_name_list.append(user_group_nickmane) + + if people_num > len(group_user_name_list): + await lottery.finish(f'【错误】抽奖人数大于群成员人数了QAQ') + elif people_num > 100: + await lottery.finish(f'【错误】抽奖人数太多啦QAQ') + + lottery_result = random.sample(group_user_name_list, k=people_num) + msg = '【' + str.join('】\n【', lottery_result) + '】' + await lottery.finish(f"抽奖人数: 【{people_num}】\n以下是中奖名单:\n{msg}") + else: + await lottery.finish(f'格式不对呢, 人数应该是数字') diff --git a/omega_miya/plugins/draw/data_source.py b/omega_miya/plugins/draw/data_source.py index 233535ff..54569d19 100644 --- a/omega_miya/plugins/draw/data_source.py +++ b/omega_miya/plugins/draw/data_source.py @@ -1,8 +1,11 @@ from .deck import * +from typing import Dict, Callable +T_DrawDeck = Callable[[int], str] + # Deck事件 -deck_list = { +deck_list: Dict[str, T_DrawDeck] = { '单张塔罗牌': one_tarot, '超能力': superpower, '程序员修行': course, @@ -11,5 +14,5 @@ } -def draw_deck(deck: str): +def draw_deck(deck: str) -> T_DrawDeck: return deck_list.get(deck) diff --git a/omega_miya/plugins/draw/deck/arknights.py b/omega_miya/plugins/draw/deck/arknights.py index 659546c1..34925690 100644 --- a/omega_miya/plugins/draw/deck/arknights.py +++ b/omega_miya/plugins/draw/deck/arknights.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Dict import random @@ -20,31 +20,40 @@ class UpEvent: zoom: float # up提升倍率 +# 用户抽取的保底概率提升计数 +USERS_UP_COUNT: Dict[int, int] = {} + + # 当期up干员 -UP_OPERATOR = [ +UP_OPERATOR: List[UpEvent] = [ UpEvent( star=6, operator=[ - Operator(name='森蚺/Eunectes', star=6, limited=False, recruit_only=False, event_only=False, - special_only=False), - Operator(name='阿/Aak', star=6, limited=False, recruit_only=False, event_only=False, special_only=False) + Operator(name='早露/Роса', star=6, limited=False, recruit_only=False, event_only=False, special_only=False), + Operator(name='安洁莉娜/Angelina', star=6, limited=False, recruit_only=False, event_only=False, + special_only=False) ], zoom=0.5 ), UpEvent( star=5, operator=[ - Operator(name='白面鸮/Ptilopsis', star=5, limited=False, recruit_only=False, event_only=False, + Operator(name='普罗旺斯/Provence', star=5, limited=False, recruit_only=False, event_only=False, special_only=False), - Operator(name='真理/Истина', star=5, limited=False, recruit_only=False, event_only=False, special_only=False), - Operator(name='蓝毒/Blue Poison', star=5, limited=False, recruit_only=False, event_only=False, + Operator(name='梅尔/Mayer', star=5, limited=False, recruit_only=False, event_only=False, special_only=False), + Operator(name='乌有/Mr.Nothing', star=5, limited=False, recruit_only=False, event_only=False, special_only=False) ], zoom=0.5 ) ] -ALL_OPERATOR = [ +ALL_OPERATOR: List[Operator] = [ + Operator(name='帕拉斯/Pallas', star=6, limited=False, recruit_only=False, event_only=False, special_only=False), + Operator(name='卡涅利安/Carnelian', star=6, limited=False, recruit_only=False, event_only=False, special_only=False), + Operator(name='绮良/Kirara', star=5, limited=False, recruit_only=False, event_only=False, special_only=False), + Operator(name='贝娜/Bena', star=5, limited=False, recruit_only=False, event_only=True, special_only=False), + Operator(name='深靛/Indigo', star=4, limited=False, recruit_only=False, event_only=False, special_only=False), Operator(name='浊心斯卡蒂/Skadi the Corrupting Heart', star=6, limited=True, recruit_only=False, event_only=False, special_only=False), Operator(name="凯尔希/Kal'tsit", star=6, limited=False, recruit_only=False, event_only=False, special_only=False), @@ -241,9 +250,32 @@ class UpEvent: ] -def draw_one_operator() -> str: - # 先决定出的星级 - star = random.sample([6, 5, 4, 3], counts=[2, 8, 50, 40], k=1)[0] +def draw_one_operator(user_id: int) -> str: + global USERS_UP_COUNT + draw_count = USERS_UP_COUNT.get(user_id, 0) + + # 首先要先决定出的星级 + if 0 <= draw_count <= 50: + # 没有抽过或者刚刚重置过, 无概率提升 + star = random.sample([6, 5, 4, 3], counts=[2, 8, 50, 40], k=1)[0] + USERS_UP_COUNT.update({user_id: draw_count + 1}) + elif 50 < draw_count <= 99: + # 触发概率提升 + if random.randint(1, 100) <= (draw_count - 49) * 2: + # 触发概率提升则为6星 + star = 6 + else: + # 否则则在5, 4, 3星中随机 + star = random.sample([5, 4, 3], counts=[8, 50, 40], k=1)[0] + USERS_UP_COUNT.update({user_id: draw_count + 1}) + else: + # 多半是出bug了, 强制重置次数 + star = random.sample([6, 5, 4, 3], counts=[2, 8, 50, 40], k=1)[0] + USERS_UP_COUNT.update({user_id: 1}) + + # 如果出6星了就重置up次数 + if star == 6: + USERS_UP_COUNT.update({user_id: 0}) # 生成对应卡池和处理up事件 up_event = [(x.zoom, x.operator) for x in UP_OPERATOR if x.star == star] @@ -279,7 +311,7 @@ def draw_one_arknights(user_id: int) -> str: up_up_operator = '\n'.join(up_operators) up_info = f'当期UP干员:\n{up_up_operator}' - acquire_operator = draw_one_operator() + acquire_operator = draw_one_operator(user_id=user_id) return f"获得了以下干员:\n{acquire_operator}\n{'='*12}\n{up_info}" @@ -294,7 +326,7 @@ def draw_ten_arknights(user_id: int) -> str: acquire_operators = [] for i in range(10): - acquire_operators.append(draw_one_operator()) + acquire_operators.append(draw_one_operator(user_id=user_id)) acquire_operator = '\n'.join(acquire_operators) diff --git a/omega_miya/plugins/http_cat/__init__.py b/omega_miya/plugins/http_cat/__init__.py new file mode 100644 index 00000000..bf017c61 --- /dev/null +++ b/omega_miya/plugins/http_cat/__init__.py @@ -0,0 +1,94 @@ +""" +@Author : Ailitonia +@Date : 2021/05/30 16:47 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : Get http cat +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import re +from nonebot import on_command, export, logger +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP, PRIVATE_FRIEND +from nonebot.adapters.cqhttp.message import MessageSegment +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state +from .data_source import get_http_cat + + +# Custom plugin usage text +__plugin_name__ = 'HttpCat' +__plugin_usage__ = r'''【Http Cat】 +就是喵喵喵 + +**Permission** +Friend Private +Command & Lv.30 +or AuthNode + +**AuthNode** +basic + +**Usage** +/HttpCat ''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +# 注册事件响应器 +httpcat = on_command( + 'HttpCat', + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='httpcat', + command=True, + level=30, + auth_node='basic'), + aliases={'httpcat', 'HTTPCAT'}, + permission=GROUP | PRIVATE_FRIEND, + priority=20, + block=True) + + +# 修改默认参数处理 +@httpcat.args_parser +async def parse(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await httpcat.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await httpcat.finish('操作已取消') + + +@httpcat.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['code'] = args[0] + else: + await httpcat.finish('参数错误QAQ') + + +@httpcat.got('code', prompt='http code?') +async def handle_httpcat(bot: Bot, event: MessageEvent, state: T_State): + code = state['code'] + if not re.match(r'^\d+$', code): + await httpcat.finish('Http code is number!') + res = await get_http_cat(http_code=code) + if res.success() and res.result: + img_seg = MessageSegment.image(res.result) + await httpcat.finish(img_seg) + else: + await httpcat.finish('^QAQ^') diff --git a/omega_miya/plugins/http_cat/data_source.py b/omega_miya/plugins/http_cat/data_source.py new file mode 100644 index 00000000..c1b2379d --- /dev/null +++ b/omega_miya/plugins/http_cat/data_source.py @@ -0,0 +1,39 @@ +""" +@Author : Ailitonia +@Date : 2021/05/30 16:48 +@FileName : data_source.py +@Project : nonebot2_miya +@Description : Http Cat Utils +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import os +import pathlib +from nonebot import get_driver +from omega_miya.utils.Omega_plugin_utils import HttpFetcher +from omega_miya.utils.Omega_Base import Result + +global_config = get_driver().config +TMP_PATH = global_config.tmp_path_ + +API_URL = 'https://http.cat/' + + +async def get_http_cat(http_code: int) -> Result.TextResult: + file_name = f'{http_code}.jpg' + folder_path = os.path.abspath(os.path.join(TMP_PATH, 'http_cat')) + file_path = os.path.abspath(os.path.join(folder_path, file_name)) + if os.path.exists(file_path): + file_url = pathlib.Path(file_path).as_uri() + return Result.TextResult(error=False, info='Success, file exists', result=file_url) + + url = f'{API_URL}{http_code}.jpg' + fetcher = HttpFetcher(timeout=10, flag='http_cat') + result = await fetcher.download_file(url=url, path=folder_path, file_name=file_name) + + if result.success(): + file_url = pathlib.Path(result.result).as_uri() + return Result.TextResult(error=False, info='Success', result=file_url) + else: + return Result.TextResult(error=True, info=result.info, result='') diff --git a/omega_miya/plugins/maybe/__init__.py b/omega_miya/plugins/maybe/__init__.py index 28015fd4..9b6411b7 100644 --- a/omega_miya/plugins/maybe/__init__.py +++ b/omega_miya/plugins/maybe/__init__.py @@ -19,13 +19,22 @@ **Permission** Command & Lv.10 +or AuthNode + +**AuthNode** +basic **Usage** /求签 [所求之事] /DD老黄历''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) # 注册事件响应器 Maybe = CommandGroup( @@ -34,7 +43,8 @@ state=init_permission_state( name='maybe', command=True, - level=10), + level=10, + auth_node='basic'), permission=GROUP, priority=10, block=True) diff --git a/omega_miya/plugins/miya_button/__init__.py b/omega_miya/plugins/miya_button/__init__.py index 0d054c7c..cf227853 100644 --- a/omega_miya/plugins/miya_button/__init__.py +++ b/omega_miya/plugins/miya_button/__init__.py @@ -1,5 +1,6 @@ import re import os +import pathlib from nonebot import MatcherGroup, logger from nonebot.typing import T_State from nonebot.rule import to_me @@ -35,10 +36,13 @@ @miya_button.handle() async def handle_miya_button(bot: Bot, event: GroupMessageEvent, state: T_State): arg = str(event.get_plaintext()).strip().lower() - voice = re.sub('喵一个', '', arg) - voice_file = miya_voices.get_voice(keyword=voice) - if not os.path.exists(voice_file): - await miya_button.send('喵?') + keyword = re.sub('喵一个', '', arg) + voice_file = miya_voices.get_voice(keyword=keyword) + if not voice_file: + await miya_button.finish(f'{keyword}是什么不懂喵') + elif not os.path.exists(voice_file): + await miya_button.finish('喵?') else: - msg = MessageSegment.record(file=f'file:///{voice_file}') - await miya_button.send(msg) + file_url = pathlib.Path(voice_file).as_uri() + msg = MessageSegment.record(file=file_url) + await miya_button.finish(msg) diff --git a/omega_miya/plugins/miya_button/resources/data_classes.py b/omega_miya/plugins/miya_button/resources/data_classes.py index 1507d36f..a94a17d1 100644 --- a/omega_miya/plugins/miya_button/resources/data_classes.py +++ b/omega_miya/plugins/miya_button/resources/data_classes.py @@ -21,11 +21,14 @@ class Voice: voices: List[VoiceFile] def get_voice(self, keyword: str) -> Optional[str]: - result = [x for x in self.voices if x.name == keyword] - if not result: - result = [x for x in self.voices if x.tag == keyword] - if not result: + if keyword: + result = [x for x in self.voices if x.name == keyword] + if not result: + result = [x for x in self.voices if x.tag == keyword] + else: result = self.voices + if not result: + return None voice = random.choice(result) return os.path.abspath(os.path.join(voice.folder_path, voice.file_name)) diff --git a/omega_miya/plugins/miya_button/resources/miya_voices.py b/omega_miya/plugins/miya_button/resources/miya_voices.py index 1ede7caa..04cba2bb 100644 --- a/omega_miya/plugins/miya_button/resources/miya_voices.py +++ b/omega_miya/plugins/miya_button/resources/miya_voices.py @@ -43,7 +43,25 @@ VoiceFile(name='起床了大笨蛋', file_name='34.mp3', folder_path=__voices_folder, tag='阴阳'), VoiceFile(name='起来了dd', file_name='35.mp3', folder_path=__voices_folder, tag='阴阳'), VoiceFile(name='作业写了吗', file_name='36.mp3', folder_path=__voices_folder, tag='阴阳'), - VoiceFile(name='异世相遇尽享美味', file_name='37.mp3', folder_path=__voices_folder, tag='卖萌') + VoiceFile(name='异世相遇尽享美味', file_name='37.mp3', folder_path=__voices_folder, tag='卖萌'), + VoiceFile(name='猫猫坏笑', file_name='38.mp3', folder_path=__voices_folder, tag='阴阳'), + VoiceFile(name='猫猫反派笑', file_name='39.mp3', folder_path=__voices_folder, tag='阴阳'), + VoiceFile(name='猫猫反派笑', file_name='40.mp3', folder_path=__voices_folder, tag='阴阳'), + VoiceFile(name='猫猫反派笑', file_name='41.mp3', folder_path=__voices_folder, tag='阴阳'), + VoiceFile(name='niya异世相遇', file_name='42.mp3', folder_path=__voices_folder, tag='卖萌'), + VoiceFile(name='啊啊啊啊~', file_name='43.mp3', folder_path=__voices_folder, tag='怪叫'), + VoiceFile(name='这游戏有问题', file_name='44.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='嗷呜', file_name='45.mp3', folder_path=__voices_folder, tag='怪叫'), + VoiceFile(name='不要吃我', file_name='46.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='不要打我', file_name='47.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='别别别', file_name='48.mp3', folder_path=__voices_folder, tag='怪叫'), + VoiceFile(name='擦盘子', file_name='49.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='吵死了', file_name='50.mp3', folder_path=__voices_folder, tag='卖萌'), + VoiceFile(name='打嗝', file_name='51.mp3', folder_path=__voices_folder, tag='怪叫'), + VoiceFile(name='打哈欠', file_name='52.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='打嗝嗷嗷', file_name='53.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='喵喵喵喵喵喵', file_name='54.mp3', folder_path=__voices_folder, tag='普通'), + VoiceFile(name='晚安臭DD', file_name='55.mp3', folder_path=__voices_folder, tag='卖萌') ] ) diff --git a/omega_miya/plugins/nbnhhsh/__init__.py b/omega_miya/plugins/nbnhhsh/__init__.py index 4ff35fb4..163bd2e5 100644 --- a/omega_miya/plugins/nbnhhsh/__init__.py +++ b/omega_miya/plugins/nbnhhsh/__init__.py @@ -1,4 +1,3 @@ -import re from nonebot import on_command, export, logger from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot @@ -17,13 +16,21 @@ **Permission** Friend Private Command & Lv.30 +or AuthNode + +**AuthNode** +basic **Usage** /好好说话 [缩写]''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) # 注册事件响应器 @@ -33,7 +40,8 @@ state=init_permission_state( name='nbnhhsh', command=True, - level=30), + level=30, + auth_node='basic'), aliases={'hhsh', 'nbnhhsh'}, permission=GROUP | PRIVATE_FRIEND, priority=20, @@ -65,23 +73,20 @@ async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): @nbnhhsh.got('guess', prompt='有啥缩写搞不懂?') async def handle_nbnhhsh(bot: Bot, event: MessageEvent, state: T_State): guess = state['guess'] - if re.match(r'^[a-zA-Z0-9]+$', guess): - res = await get_guess(guess=guess) - if res.success() and res.result: - try: - data = dict(res.result[0]) - except Exception as e: - logger.error(f'nbnhhsh error: {repr(e)}') - await nbnhhsh.finish('发生了意外的错误QAQ, 请稍后再试') - return - if data.get('trans'): - trans = str.join('\n', data.get('trans')) - msg = f"为你找到了{guess}的以下解释:\n\n{trans}" - await nbnhhsh.finish(msg) - elif data.get('inputting'): - trans = str.join('\n', data.get('inputting')) - msg = f"为你找到了{guess}的以下解释:\n\n{trans}" - await nbnhhsh.finish(msg) - await nbnhhsh.finish(f'没有找到{guess}的相关解释QAQ') - else: - await nbnhhsh.finish('缩写仅支持字母加数字, 请重新输入') + res = await get_guess(guess=guess) + if res.success() and res.result: + try: + data = dict(res.result[0]) + except Exception as e: + logger.error(f'nbnhhsh error: {repr(e)}') + await nbnhhsh.finish('发生了意外的错误QAQ, 请稍后再试') + return + if data.get('trans'): + trans = str.join('\n', data.get('trans')) + msg = f"为你找到了{guess}的以下解释:\n\n{trans}" + await nbnhhsh.finish(msg) + elif data.get('inputting'): + trans = str.join('\n', data.get('inputting')) + msg = f"为你找到了{guess}的以下解释:\n\n{trans}" + await nbnhhsh.finish(msg) + await nbnhhsh.finish(f'没有找到{guess}的相关解释QAQ') diff --git a/omega_miya/plugins/pixiv/__init__.py b/omega_miya/plugins/pixiv/__init__.py index 484a788f..dab21572 100644 --- a/omega_miya/plugins/pixiv/__init__.py +++ b/omega_miya/plugins/pixiv/__init__.py @@ -1,12 +1,23 @@ import re -from nonebot import on_command, export, logger +import asyncio +from typing import Optional +from nonebot import on_command, export, logger, get_driver from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP, PRIVATE_FRIEND from nonebot.adapters.cqhttp import MessageSegment, Message -from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PluginCoolDown, check_auth_node +from omega_miya.utils.Omega_Base import DBBot +from omega_miya.utils.Omega_plugin_utils import \ + init_export, init_permission_state, PluginCoolDown, PermissionChecker, MsgSender from omega_miya.utils.pixiv_utils import PixivIllust +from .config import Config + + +__global_config = get_driver().config +plugin_config = Config(**__global_config.dict()) +ENABLE_NODE_CUSTOM = plugin_config.enable_node_custom + # Custom plugin usage text __plugin_name__ = 'Pixiv' @@ -96,70 +107,76 @@ async def handle_pixiv(bot: Bot, event: MessageEvent, state: T_State): mode = state['mode'] if mode == '日榜': await pixiv.send('稍等, 正在下载图片~') - rank_result = await PixivIllust.daily_ranking() + rank_result = await PixivIllust.get_ranking(mode='daily') if rank_result.error: logger.warning(f"User: {event.user_id} 获取Pixiv Rank失败, {rank_result.info}") await pixiv.finish('加载失败, 网络超时QAQ') + tasks = [] for rank, illust_data in dict(rank_result.result).items(): - rank += 1 - illust_id = illust_data.get('illust_id') - illust_title = illust_data.get('illust_title') - illust_uname = illust_data.get('illust_uname') - - image_result = await PixivIllust(pid=illust_id).pic_2_base64() - if image_result.success(): - msg = f'No.{rank} - ID: {illust_id}\n「{illust_title}」/「{illust_uname}」' - img_seg = MessageSegment.image(image_result.result) - await pixiv.send(Message(img_seg).append(msg)) - else: - logger.warning(f"下载图片失败, pid: {illust_id}, {image_result.info}") if rank >= 10: break + tasks.append(__handle_ranking_msg(rank=rank, illust_data=illust_data)) + ranking_msg_result = list(await asyncio.gather(*tasks)) + + # 根据ENABLE_NODE_CUSTOM处理消息发送 + if ENABLE_NODE_CUSTOM and isinstance(event, GroupMessageEvent): + msg_sender = MsgSender(bot=bot, log_flag='PixivDailyRanking') + await msg_sender.safe_send_group_node_custom(group_id=event.group_id, message_list=ranking_msg_result) + else: + for msg_seg in ranking_msg_result: + try: + await pixiv.send(msg_seg) + except Exception as e: + logger.warning(f'图片发送失败, user: {event.user_id}. error: {repr(e)}') elif mode == '周榜': await pixiv.send('稍等, 正在下载图片~') - rank_result = await PixivIllust.weekly_ranking() + rank_result = await PixivIllust.get_ranking(mode='weekly') if rank_result.error: logger.warning(f"User: {event.user_id} 获取Pixiv Rank失败, {rank_result.info}") await pixiv.finish('加载失败, 网络超时QAQ') + tasks = [] for rank, illust_data in dict(rank_result.result).items(): - rank += 1 - illust_id = illust_data.get('illust_id') - illust_title = illust_data.get('illust_title') - illust_uname = illust_data.get('illust_uname') - - image_result = await PixivIllust(pid=illust_id).pic_2_base64() - if image_result.success(): - msg = f'No.{rank} - ID: {illust_id}\n「{illust_title}」/「{illust_uname}」' - img_seg = MessageSegment.image(image_result.result) - await pixiv.send(Message(img_seg).append(msg)) - else: - logger.warning(f"下载图片失败, pid: {illust_id}, {image_result.info}") if rank >= 10: break + tasks.append(__handle_ranking_msg(rank=rank, illust_data=illust_data)) + ranking_msg_result = list(await asyncio.gather(*tasks)) + + # 根据ENABLE_NODE_CUSTOM处理消息发送 + if ENABLE_NODE_CUSTOM and isinstance(event, GroupMessageEvent): + msg_sender = MsgSender(bot=bot, log_flag='PixivWeeklyRanking') + await msg_sender.safe_send_group_node_custom(group_id=event.group_id, message_list=ranking_msg_result) + else: + for msg_seg in ranking_msg_result: + try: + await pixiv.send(msg_seg) + except Exception as e: + logger.warning(f'图片发送失败, user: {event.user_id}. error: {repr(e)}') elif mode == '月榜': await pixiv.send('稍等, 正在下载图片~') - rank_result = await PixivIllust.monthly_ranking() + rank_result = await PixivIllust.get_ranking(mode='monthly') if rank_result.error: logger.warning(f"User: {event.user_id} 获取Pixiv Rank失败, {rank_result.info}") await pixiv.finish('加载失败, 网络超时QAQ') + tasks = [] for rank, illust_data in dict(rank_result.result).items(): - rank += 1 - illust_id = illust_data.get('illust_id') - illust_title = illust_data.get('illust_title') - illust_uname = illust_data.get('illust_uname') - - image_result = await PixivIllust(pid=illust_id).pic_2_base64() - if image_result.success(): - msg = f'No.{rank} - ID: {illust_id}\n「{illust_title}」/「{illust_uname}」' - img_seg = MessageSegment.image(image_result.result) - await pixiv.send(Message(img_seg).append(msg)) - else: - logger.warning(f"下载图片失败, pid: {illust_id}, {image_result.info}") if rank >= 10: break + tasks.append(__handle_ranking_msg(rank=rank, illust_data=illust_data)) + ranking_msg_result = list(await asyncio.gather(*tasks)) + + # 根据ENABLE_NODE_CUSTOM处理消息发送 + if ENABLE_NODE_CUSTOM and isinstance(event, GroupMessageEvent): + msg_sender = MsgSender(bot=bot, log_flag='PixivMonthlyRanking') + await msg_sender.safe_send_group_node_custom(group_id=event.group_id, message_list=ranking_msg_result) + else: + for msg_seg in ranking_msg_result: + try: + await pixiv.send(msg_seg) + except Exception as e: + logger.warning(f'图片发送失败, user: {event.user_id}. error: {repr(e)}') elif re.match(r'^\d+$', mode): pid = mode logger.debug(f'开始获取Pixiv资源: {pid}.') @@ -174,10 +191,12 @@ async def handle_pixiv(bot: Bot, event: MessageEvent, state: T_State): if illust_data_result.result.get('is_r18'): if isinstance(event, PrivateMessageEvent): user_id = event.user_id - auth_checker = await check_auth_node(auth_id=user_id, auth_type='user', auth_node='pixiv.allow_r18') + auth_checker = await PermissionChecker(self_bot=DBBot(self_qq=int(bot.self_id))).\ + check_auth_node(auth_id=user_id, auth_type='user', auth_node='pixiv.allow_r18') elif isinstance(event, GroupMessageEvent): group_id = event.group_id - auth_checker = await check_auth_node(auth_id=group_id, auth_type='group', auth_node='pixiv.allow_r18') + auth_checker = await PermissionChecker(self_bot=DBBot(self_qq=int(bot.self_id))).\ + check_auth_node(auth_id=group_id, auth_type='group', auth_node='pixiv.allow_r18') else: auth_checker = 0 @@ -185,15 +204,23 @@ async def handle_pixiv(bot: Bot, event: MessageEvent, state: T_State): logger.warning(f"User: {event.user_id} 获取Pixiv资源 {pid} 被拒绝, 没有 allow_r18 权限") await pixiv.finish('R18禁止! 不准开车车!') + # 区分作品类型 + illust_type = illust_data_result.result.get('illust_type') await pixiv.send('稍等, 正在下载图片~') - illust_result = await illust.pic_2_base64() - if illust_result.success(): - msg = illust_result.info + illust_info_result = await illust.get_format_info_msg() + if illust_type == 2: + # 动图作品生成动图后发送 + illust_result = await illust.get_ugoira_gif_filepath() + else: + illust_result = await illust.get_file() + if illust_result.success() and illust_info_result.success(): + msg = illust_info_result.result img_seg = MessageSegment.image(illust_result.result) # 发送图片和图片信息 await pixiv.send(Message(img_seg).append(msg)) else: - logger.warning(f"User: {event.user_id} 获取Pixiv资源失败, 网络超时或 {pid} 不存在, {illust_result.info}") + logger.warning(f"User: {event.user_id} 获取Pixiv资源失败, 网络超时或 {pid} 不存在, " + f"{illust_info_result.info} // {illust_result.info}") await pixiv.send('加载失败, 网络超时或没有这张图QAQ') else: await pixiv.reject('你输入的命令好像不对呢……请输入"月榜"、"周榜"、"日榜"或者PixivID, 取消命令请发送【取消】:') @@ -264,7 +291,25 @@ async def handle_pixiv_dl(bot: Bot, event: GroupMessageEvent, state: T_State): await bot.call_api(api='upload_group_file', group_id=event.group_id, file=file_path, name=file_name) except Exception as e: logger.warning(f'User: {event.user_id} 下载Pixiv资源失败, 上传群文件失败: {repr(e)}') - await pixiv_dl.finish('上传图片到群文件失败QAQ, 请稍后再试') + await pixiv_dl.finish('上传图片到群文件失败QAQ, 可能获取上传结果超时但上传仍在进行中, 请等待1~2分钟后再重试') else: await pixiv_dl.finish('参数错误, pid应为纯数字') + + +# 处理Pixiv.__ranking榜单消息 +async def __handle_ranking_msg(rank: int, illust_data: dict) -> Optional[Message]: + rank += 1 + + illust_id = illust_data.get('illust_id') + illust_title = illust_data.get('illust_title') + illust_uname = illust_data.get('illust_uname') + + image_result = await PixivIllust(pid=illust_id).get_file() + if image_result.success(): + msg = f'No.{rank} - ID: {illust_id}\n「{illust_title}」/「{illust_uname}」' + img_seg = MessageSegment.image(image_result.result) + return Message(img_seg).append(msg) + else: + logger.warning(f"下载图片失败, pid: {illust_id}, {image_result.info}") + return None diff --git a/omega_miya/plugins/pixiv/config.py b/omega_miya/plugins/pixiv/config.py new file mode 100644 index 00000000..bfdedbad --- /dev/null +++ b/omega_miya/plugins/pixiv/config.py @@ -0,0 +1,21 @@ +""" +@Author : Ailitonia +@Date : 2021/06/13 18:48 +@FileName : config.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from pydantic import BaseSettings + + +class Config(BaseSettings): + # plugin custom config + # 启用使用群组转发自定义消息节点的模式发送信息 + # 发送速度受限于网络上传带宽, 有可能导致超时或发送失败, 请酌情启用 + enable_node_custom: bool = False + + class Config: + extra = "ignore" diff --git a/omega_miya/plugins/pixiv_monitor/__init__.py b/omega_miya/plugins/pixiv_monitor/__init__.py new file mode 100644 index 00000000..29fe4966 --- /dev/null +++ b/omega_miya/plugins/pixiv_monitor/__init__.py @@ -0,0 +1,265 @@ +""" +@Author : Ailitonia +@Date : 2021/06/01 22:06 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : Pixiv 用户作品订阅 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import re +from nonebot import on_command, export, logger +from nonebot.permission import SUPERUSER +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER, PRIVATE_FRIEND +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBFriend, DBSubscription, Result +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state +from omega_miya.utils.pixiv_utils import PixivUser +from .monitor import scheduler, init_new_add_sub + + +# Custom plugin usage text +__plugin_name__ = 'Pixiv画师订阅' +__plugin_usage__ = r'''【Pixiv画师订阅】 +随时更新Pixiv画师作品 +仅限群聊使用 + +**Permission** +Command & Lv.50 +or AuthNode + +**AuthNode** +basic + +**Usage** +**GroupAdmin and SuperUser Only** +/Pixiv画师 订阅 [UID] +/Pixiv画师 取消订阅 [UID] +/Pixiv画师 清空订阅 +/Pixiv画师 订阅列表''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + +# 注册事件响应器 +pixiv_user_artwork = on_command( + 'Pixiv画师', + aliases={'pixiv画师', 'p站画师', 'P站画师'}, + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='pixiv_user_artwork', + command=True, + level=50, + auth_node='basic'), + permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, + priority=20, + block=True) + + +# 修改默认参数处理 +@pixiv_user_artwork.args_parser +async def parse(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await pixiv_user_artwork.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await pixiv_user_artwork.finish('操作已取消') + + +@pixiv_user_artwork.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['sub_command'] = args[0] + elif args and len(args) == 2: + state['sub_command'] = args[0] + state['uid'] = args[1] + else: + await pixiv_user_artwork.finish('参数错误QAQ') + + +@pixiv_user_artwork.got('sub_command', prompt='执行操作?\n【订阅/取消订阅/清空订阅/订阅列表】') +async def handle_sub_command_args(bot: Bot, event: MessageEvent, state: T_State): + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + msg = '本群已订阅以下Pixiv用户:\n' + else: + group_id = 'Private event' + msg = '你已订阅以下Pixiv用户:\n' + + if state['sub_command'] not in ['订阅', '取消订阅', '清空订阅', '订阅列表']: + await pixiv_user_artwork.finish('没有这个命令哦, 请在【订阅/取消订阅/清空订阅/订阅列表】中选择并重新发送') + if state['sub_command'] == '订阅列表': + _res = await sub_list(bot=bot, event=event, state=state) + if not _res.success(): + logger.error(f"查询Pixiv订阅失败, {group_id} / {event.user_id}, error: {_res.info}") + await pixiv_user_artwork.finish('查询Pixiv订阅失败QAQ, 请稍后再试吧') + if not _res.result: + msg = '当前没有任何Pixiv订阅' + else: + for sub_id, up_name in _res.result: + msg += f'\n【{sub_id}/{up_name}】' + await pixiv_user_artwork.finish(msg) + elif state['sub_command'] == '清空订阅': + state['uid'] = None + + +@pixiv_user_artwork.got('uid', prompt='请输入订阅Pixiv用户UID:') +async def handle_uid(bot: Bot, event: MessageEvent, state: T_State): + sub_command = state['sub_command'] + # 针对清空Pixiv订阅操作, 跳过获取Pixiv用户信息 + if state['sub_command'] == '清空订阅': + await pixiv_user_artwork.pause('【警告!】\n即将清空本所有订阅!\n请发送任意消息以继续操作:') + # Pixiv用户信息获取部分 + uid = state['uid'] + if not re.match(r'^\d+$', uid): + await pixiv_user_artwork.reject('这似乎不是UID呢, 请重新输入:') + _res = await PixivUser(uid=int(uid)).get_info() + if not _res.success(): + logger.error(f'获取用户信息失败, uid: {uid}, error: {_res.info}') + await pixiv_user_artwork.finish('获取用户信息失败了QAQ, 请稍后再试~') + up_name = _res.result.get('name') + state['up_name'] = up_name + msg = f'即将{sub_command}【{up_name}】的作品!' + await pixiv_user_artwork.send(msg) + + +@pixiv_user_artwork.got('check', prompt='确认吗?\n\n【是/否】') +async def handle_check(bot: Bot, event: MessageEvent, state: T_State): + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + else: + group_id = 'Private event' + + check_msg = state['check'] + uid = state['uid'] + if check_msg != '是': + await pixiv_user_artwork.finish('操作已取消') + sub_command = state['sub_command'] + if sub_command == '订阅': + _res = await sub_add(bot=bot, event=event, state=state) + elif sub_command == '取消订阅': + _res = await sub_del(bot=bot, event=event, state=state) + elif sub_command == '清空订阅': + _res = await sub_clear(bot=bot, event=event, state=state) + else: + _res = Result.IntResult(error=True, info='Unknown error, except sub_command', result=-1) + if _res.success(): + logger.info(f"{sub_command}Pixiv用户作品成功, {group_id} / {event.user_id}, uid: {uid}") + await pixiv_user_artwork.finish(f'{sub_command}成功!') + else: + logger.error(f"{sub_command}Pixiv用户作品失败, {group_id} / {event.user_id}, uid: {uid}," + f"info: {_res.info}") + await pixiv_user_artwork.finish(f'{sub_command}失败了QAQ, 可能并未订阅该用户, 或请稍后再试~') + + +async def sub_list(bot: Bot, event: MessageEvent, state: T_State) -> Result.ListResult: + self_bot = DBBot(self_qq=int(bot.self_id)) + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + result = await group.subscription_list_by_type(sub_type=9) + return result + elif isinstance(event, PrivateMessageEvent): + user_id = event.user_id + friend = DBFriend(user_id=user_id, self_bot=self_bot) + result = await friend.subscription_list_by_type(sub_type=9) + return result + else: + return Result.ListResult(error=True, info='Illegal event', result=[]) + + +async def sub_add(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) + uid = state['uid'] + sub = DBSubscription(sub_type=9, sub_id=uid) + need_init = not (await sub.exist()) + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + _res = await sub.add(up_name=state.get('up_name'), live_info='Pixiv用户作品订阅') + if not _res.success(): + return _res + # 初次订阅时立即刷新, 避免订阅后发送旧作品的问题 + if need_init: + await bot.send(event=event, message='初次订阅, 正在初始化作品信息, 可能需要1~2分钟, 请稍后...') + await init_new_add_sub(user_id=uid) + _res = await group.subscription_add(sub=sub, group_sub_info='Pixiv用户作品订阅') + if not _res.success(): + return _res + result = Result.IntResult(error=False, info='Success', result=0) + return result + elif isinstance(event, PrivateMessageEvent): + user_id = event.user_id + friend = DBFriend(user_id=user_id, self_bot=self_bot) + _res = await sub.add(up_name=state.get('up_name'), live_info='Pixiv用户作品订阅') + if not _res.success(): + return _res + # 初次订阅时立即刷新, 避免订阅后发送旧作品的问题 + if need_init: + await bot.send(event=event, message='初次订阅, 正在初始化作品信息, 可能需要1~2分钟, 请稍后...') + await init_new_add_sub(user_id=uid) + _res = await friend.subscription_add(sub=sub, user_sub_info='Pixiv用户作品订阅') + if not _res.success(): + return _res + result = Result.IntResult(error=False, info='Success', result=0) + return result + else: + return Result.IntResult(error=True, info='Illegal event', result=-1) + + +async def sub_del(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + uid = state['uid'] + _res = await group.subscription_del(sub=DBSubscription(sub_type=9, sub_id=uid)) + if not _res.success(): + return _res + result = Result.IntResult(error=False, info='Success', result=0) + return result + elif isinstance(event, PrivateMessageEvent): + user_id = event.user_id + friend = DBFriend(user_id=user_id, self_bot=self_bot) + uid = state['uid'] + _res = await friend.subscription_del(sub=DBSubscription(sub_type=9, sub_id=uid)) + if not _res.success(): + return _res + result = Result.IntResult(error=False, info='Success', result=0) + return result + else: + return Result.IntResult(error=True, info='Illegal event', result=-1) + + +async def sub_clear(bot: Bot, event: MessageEvent, state: T_State) -> Result.IntResult: + self_bot = DBBot(self_qq=int(bot.self_id)) + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + _res = await group.subscription_clear_by_type(sub_type=9) + if not _res.success(): + return _res + result = Result.IntResult(error=False, info='Success', result=0) + return result + elif isinstance(event, PrivateMessageEvent): + user_id = event.user_id + friend = DBFriend(user_id=user_id, self_bot=self_bot) + _res = await friend.subscription_clear_by_type(sub_type=9) + if not _res.success(): + return _res + result = Result.IntResult(error=False, info='Success', result=0) + return result + else: + return Result.IntResult(error=True, info='Illegal event', result=-1) diff --git a/omega_miya/plugins/pixiv_monitor/config.py b/omega_miya/plugins/pixiv_monitor/config.py new file mode 100644 index 00000000..a2a74d01 --- /dev/null +++ b/omega_miya/plugins/pixiv_monitor/config.py @@ -0,0 +1,23 @@ +""" +@Author : Ailitonia +@Date : 2021/08/04 0:10 +@FileName : config.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from pydantic import BaseSettings + + +class Config(BaseSettings): + + # plugin custom config + """ + 检查模式, 是否启用检查池模式 + """ + enable_check_pool_mode: bool = True + + class Config: + extra = "ignore" diff --git a/omega_miya/plugins/pixiv_monitor/monitor.py b/omega_miya/plugins/pixiv_monitor/monitor.py new file mode 100644 index 00000000..7180c799 --- /dev/null +++ b/omega_miya/plugins/pixiv_monitor/monitor.py @@ -0,0 +1,266 @@ +""" +@Author : Ailitonia +@Date : 2021/06/01 22:28 +@FileName : monitor.py +@Project : nonebot2_miya +@Description : Pixiv User Monitor +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import asyncio +import random +from nonebot import logger, require, get_bots, get_driver +from nonebot.adapters.cqhttp import MessageSegment, Message +from omega_miya.utils.Omega_Base import DBSubscription, DBPixivUserArtwork +from omega_miya.utils.pixiv_utils import PixivUser, PixivIllust +from omega_miya.utils.Omega_plugin_utils import MsgSender, PicEffector, PicEncoder, ProcessUtils +from .config import Config + + +__global_config = get_driver().config +plugin_config = Config(**__global_config.dict()) +ENABLE_CHECK_POOL_MODE = plugin_config.enable_check_pool_mode + + +# 检查队列 +CHECKING_POOL = [] + + +# 启用检查Pixiv用户作品状态的定时任务 +scheduler = require("nonebot_plugin_apscheduler").scheduler + + +# 创建用于更新数据库里面画师名称的定时任务 +@scheduler.scheduled_job( + 'cron', + # year=None, + # month=None, + # day='*/1', + # week=None, + # day_of_week=None, + hour='1', + minute='15', + second='50', + # start_date=None, + # end_date=None, + # timezone=None, + id='pixiv_user_db_upgrade', + coalesce=True, + misfire_grace_time=60 +) +async def dynamic_db_upgrade(): + logger.debug('pixiv_user_db_upgrade: started upgrade subscription info') + sub_res = await DBSubscription.list_sub_by_type(sub_type=9) + for sub_id in sub_res.result: + sub = DBSubscription(sub_type=9, sub_id=sub_id) + user_info_result = await PixivUser(uid=sub_id).get_info() + if user_info_result.error: + logger.error(f'pixiv_user_db_upgrade: 获取用户信息失败, uid: {sub_id}, error: {user_info_result.info}') + continue + user_name = user_info_result.result.get('name') + _res = await sub.add(up_name=user_name, live_info='Pixiv用户作品订阅') + if not _res.success(): + logger.error(f'pixiv_user_db_upgrade: 更新用户信息失败, uid: {sub_id}, error: {_res.info}') + continue + logger.debug('pixiv_user_db_upgrade: upgrade subscription info completed') + + +# 创建Pixiv用户作品检查函数 +@scheduler.scheduled_job( + 'cron', + # year=None, + # month=None, + # day='*/1', + # week=None, + # day_of_week=None, + # hour=None, + minute='*/5', + # second='*/30', + # start_date=None, + # end_date=None, + # timezone=None, + id='pixiv_user_artwork_monitor', + coalesce=True, + misfire_grace_time=30 +) +async def pixiv_user_artwork_monitor(): + logger.debug(f"pixiv_user_artwork_monitor: checking started") + + # 获取当前bot列表 + bots = [bot for bot_id, bot in get_bots().items()] + + # 获取订阅表中的所有Pixiv用户订阅 + sub_res = await DBSubscription.list_sub_by_type(sub_type=9) + check_sub = [int(x) for x in sub_res.result] + + if not check_sub: + logger.debug(f'pixiv_user_artwork_monitor: no dynamic subscription, ignore.') + return + + # 注册一个异步函数用于检查Pixiv用户作品 + async def check_user_artwork(user_id: int): + # 获取pixiv用户作品内容 + user_artwork_result = await PixivUser(uid=user_id).get_artworks_info() + if user_artwork_result.error: + logger.error(f'pixiv_user_artwork_monitor: 获取用户 {user_id} 作品失败, error: {user_artwork_result.info}') + + all_artwork_list = user_artwork_result.result.get('illust_list') + manga_list = user_artwork_result.result.get('manga_list') + all_artwork_list.extend(manga_list) + + # 用户所有的作品id + exist_artwork_result = await DBPixivUserArtwork.list_artwork_by_uid(uid=user_id) + if exist_artwork_result.error: + logger.error(f'pixiv_user_artwork_monitor: 获取用户 {user_id} 已有作品失败, error: {exist_artwork_result.info}') + return + exist_artwork_list = [int(x) for x in exist_artwork_result.result] + + new_artwork = [pid for pid in all_artwork_list if pid not in exist_artwork_list] + + subscription = DBSubscription(sub_type=9, sub_id=user_id) + + for pid in new_artwork: + illust = PixivIllust(pid=pid) + illust_info_result = await illust.get_illust_data() + if illust_info_result.error: + logger.error(f'pixiv_user_artwork_monitor: 获取用户 {user_id} 作品 {pid} 信息失败, ' + f'error: {illust_info_result.info}') + continue + + uname = illust_info_result.result.get('uname') + title = illust_info_result.result.get('title') + is_r18 = illust_info_result.result.get('is_r18') + + # 下载图片 + illust_info_msg_result = await illust.get_format_info_msg() + illust_pic_bytes_result = await illust.load_illust_pic() + if illust_pic_bytes_result.error or illust_info_msg_result.error: + logger.error(f'pixiv_user_artwork_monitor: 下载用户 {user_id} 作品 {pid} 失败, ' + f'error: {illust_info_msg_result.info} // {illust_pic_bytes_result.info}.') + continue + + if is_r18: + # 添加图片处理模块, 模糊后发送 + blur_img_result = await PicEffector(image=illust_pic_bytes_result.result).gaussian_blur() + b64_img_result = await PicEncoder.bytes_to_file( + image=blur_img_result.result, folder_flag='pixiv_monitor') + img_seg = MessageSegment.image(b64_img_result.result) + else: + b64_img_result = await PicEncoder.bytes_to_file( + image=illust_pic_bytes_result.result, folder_flag='pixiv_monitor') + img_seg = MessageSegment.image(b64_img_result.result) + + intro_msg = f'【Pixiv】{uname}发布了新的作品!\n\n' + info_msg = illust_info_msg_result.result + msg = Message(intro_msg).append(img_seg).append(info_msg) + + # 向群组和好友推送消息 + for _bot in bots: + msg_sender = MsgSender(bot=_bot, log_flag='PixivUserArtworkNotice') + await msg_sender.safe_broadcast_groups_subscription(subscription=subscription, message=msg) + # await msg_sender.safe_broadcast_friends_subscription(subscription=subscription, message=msg) + + # 更新作品内容到数据库 + pixiv_user_artwork = DBPixivUserArtwork(pid=pid, uid=user_id) + _res = await pixiv_user_artwork.add(uname=uname, title=title) + if _res.success(): + logger.info(f'向数据库写入pixiv用户作品信息: {pid} 成功') + else: + logger.error(f'向数据库写入pixiv用户作品信息: {pid} 失败, error: {_res.info}') + + # 启用了检查池模式 + if ENABLE_CHECK_POOL_MODE: + global CHECKING_POOL + # checking_pool为空则上一轮检查完了, 重新往里面放新一轮的uid + if not CHECKING_POOL: + CHECKING_POOL.extend(check_sub) + + # 看下checking_pool里面还剩多少 + waiting_num = len(CHECKING_POOL) + + # 默认单次检查并发数为50, 默认检查间隔为5min + logger.debug(f'Pixiv user artwork checker pool mode debug info, Before checking_pool: {CHECKING_POOL}') + if waiting_num >= 50: + # 抽取检查对象 + now_checking = random.sample(CHECKING_POOL, k=50) + # 更新checking_pool + CHECKING_POOL = [x for x in CHECKING_POOL if x not in now_checking] + else: + now_checking = CHECKING_POOL.copy() + CHECKING_POOL.clear() + logger.debug(f'Pixiv user artwork checker pool mode debug info, After checking_pool: {CHECKING_POOL}') + logger.debug(f'Pixiv user artwork checker pool mode debug info, now_checking: {now_checking}') + + # 检查now_checking里面的直播间(异步) + tasks = [] + for uid in now_checking: + tasks.append(check_user_artwork(user_id=uid)) + try: + await asyncio.gather(*tasks) + logger.debug(f"pixiv_user_artwork_monitor: pool mode enable, checking completed, " + f"checked: {', '.join([str(x) for x in now_checking])}.") + except Exception as e: + logger.error(f'pixiv_user_artwork_monitor: error occurred in checking {repr(e)}') + + # 没有启用检查池模式 + else: + # 检查所有在订阅表里面的画师作品(异步) + tasks = [] + for uid in check_sub: + tasks.append(check_user_artwork(user_id=uid)) + try: + await asyncio.gather(*tasks) + logger.debug(f"pixiv_user_artwork_monitor: pool mode disable, checking completed, " + f"checked: {', '.join([str(x) for x in check_sub])}.") + except Exception as e: + logger.error(f'pixiv_user_artwork_monitor: error occurred in checking {repr(e)}') + + +# 用于首次订阅时刷新数据库信息 +async def init_new_add_sub(user_id: int): + # 暂停计划任务避免中途检查更新 + scheduler.pause() + try: + # 获取pixiv用户作品内容 + user_artwork_result = await PixivUser(uid=user_id).get_artworks_info() + if user_artwork_result.error: + logger.error(f'init_new_add_sub: 获取用户 {user_id} 作品失败, error: {user_artwork_result.info}') + + all_artwork_list = user_artwork_result.result.get('illust_list') + manga_list = user_artwork_result.result.get('manga_list') + all_artwork_list.extend(manga_list) + + async def _handle(pid_: int): + illust = PixivIllust(pid=pid_) + illust_info_result = await illust.get_illust_data() + if illust_info_result.error: + logger.error(f'init_new_add_sub: 获取用户 {user_id} 作品 {pid_} 信息失败, error: {illust_info_result.info}') + return + + uname = illust_info_result.result.get('uname') + title = illust_info_result.result.get('title') + + # 更新作品内容到数据库 + pixiv_user_artwork = DBPixivUserArtwork(pid=pid_, uid=user_id) + _res = await pixiv_user_artwork.add(uname=uname, title=title) + if _res.success(): + logger.debug(f'向数据库写入pixiv用户作品信息: {pid_} 成功') + else: + logger.error(f'向数据库写入pixiv用户作品信息: {pid_} 失败, error: {_res.info}') + + # 开始导入操作 + # 全部一起并发网络撑不住, 做适当切分 + tasks = [_handle(pid_=pid) for pid in all_artwork_list] + await ProcessUtils.fragment_process(tasks=tasks, fragment_size=50, log_flag='Init Pixiv User Illust') + logger.info(f'初始化pixiv用户 {user_id} 作品完成, 已将作品信息写入数据库.') + except Exception as e: + logger.error(f'初始化pixiv用户 {user_id} 作品发生错误, error: {repr(e)}.') + + scheduler.resume() + + +__all__ = [ + 'scheduler', + 'init_new_add_sub' +] diff --git a/omega_miya/plugins/pixivsion_monitor/__init__.py b/omega_miya/plugins/pixivsion_monitor/__init__.py index be1ae058..bc706237 100644 --- a/omega_miya/plugins/pixivsion_monitor/__init__.py +++ b/omega_miya/plugins/pixivsion_monitor/__init__.py @@ -4,9 +4,9 @@ from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import GroupMessageEvent from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER -from omega_miya.utils.Omega_Base import DBGroup, DBSubscription, Result +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBSubscription, Result from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state -from .monitor import * +from .monitor import scheduler, init_pixivsion_article # Custom plugin usage text @@ -17,14 +17,23 @@ **Permission** Command & Lv.30 +or AuthNode + +**AuthNode** +basic **Usage** **GroupAdmin and SuperUser Only** /Pixivision 订阅 /Pixivision 取消订阅''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) # 注册事件响应器 pixivision = on_command( @@ -34,7 +43,8 @@ state=init_permission_state( name='pixivision', command=True, - level=30), + level=30, + auth_node='basic'), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=20, block=True) @@ -84,13 +94,19 @@ async def handle_sub_command_args(bot: Bot, event: GroupMessageEvent, state: T_S async def sub_add(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) sub_id = -1 sub = DBSubscription(sub_type=8, sub_id=sub_id) + need_init = not (await sub.exist()) _res = await sub.add(up_name='Pixivision', live_info='Pixivision订阅') if not _res.success(): return _res - _res = await group.subscription_add(sub=sub) + # 初次订阅时立即刷新, 避免订阅后发送旧特辑的问题 + if need_init: + await bot.send(event=event, message='初次订阅, 正在初始化Pixivision信息, 可能需要1~2分钟, 请稍后...') + await init_pixivsion_article() + _res = await group.subscription_add(sub=sub, group_sub_info='Pixivision订阅') if not _res.success(): return _res result = Result.IntResult(error=False, info='Success', result=0) @@ -99,7 +115,8 @@ async def sub_add(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result. async def sub_del(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: group_id = event.group_id - group = DBGroup(group_id=group_id) + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) sub_id = -1 _res = await group.subscription_del(sub=DBSubscription(sub_type=8, sub_id=sub_id)) if not _res.success(): diff --git a/omega_miya/plugins/pixivsion_monitor/config.py b/omega_miya/plugins/pixivsion_monitor/config.py new file mode 100644 index 00000000..a18b6499 --- /dev/null +++ b/omega_miya/plugins/pixivsion_monitor/config.py @@ -0,0 +1,21 @@ +""" +@Author : Ailitonia +@Date : 2021/06/12 20:45 +@FileName : config.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from pydantic import BaseSettings + + +class Config(BaseSettings): + # plugin custom config + # 启用使用群组转发自定义消息节点的模式发送信息 + # 发送速度受限于网络上传带宽, 有可能导致超时或发送失败, 请酌情启用 + enable_node_custom: bool = False + + class Config: + extra = "ignore" diff --git a/omega_miya/plugins/pixivsion_monitor/monitor.py b/omega_miya/plugins/pixivsion_monitor/monitor.py index dcd3d847..55eee764 100644 --- a/omega_miya/plugins/pixivsion_monitor/monitor.py +++ b/omega_miya/plugins/pixivsion_monitor/monitor.py @@ -1,10 +1,17 @@ import asyncio -from nonebot import logger, require, get_bots +from nonebot import logger, require, get_bots, get_driver from nonebot.adapters.cqhttp import MessageSegment -from omega_miya.utils.Omega_Base import DBSubscription, DBTable +from omega_miya.utils.Omega_Base import DBSubscription, DBPixivision +from omega_miya.utils.Omega_plugin_utils import MsgSender from omega_miya.utils.pixiv_utils import PixivIllust, PixivisionArticle from .utils import pixivsion_article_parse from .block_tag import TAG_BLOCK_LIST +from .config import Config + + +__global_config = get_driver().config +plugin_config = Config(**__global_config.dict()) +ENABLE_NODE_CUSTOM = plugin_config.enable_node_custom # 启用检查动态状态的定时任务 @@ -32,14 +39,7 @@ async def pixivision_monitor(): logger.debug(f"pixivision_monitor: checking started") # 获取当前bot列表 - bots = [] - for bot_id, bot in get_bots().items(): - bots.append(bot) - - # 获取所有有通知权限的群组 - t = DBTable(table_name='Group') - group_res = await t.list_col_with_condition('group_id', 'notice_permissions', 1) - all_noitce_groups = [int(x) for x in group_res.result] + bots = [bot for bot_id, bot in get_bots().items()] # 初始化tag黑名单 block_tag_id = [] @@ -49,8 +49,7 @@ async def pixivision_monitor(): block_tag_name.append(block_tag.get('name')) # 提取数据库中已有article的id列表 - t = DBTable(table_name='Pixivision') - pixivision_res = await t.list_col(col_name='aid') + pixivision_res = await DBPixivision.list_article_id() exist_article = [int(x) for x in pixivision_res.result] # 获取最新一页pixivision的article @@ -83,12 +82,7 @@ async def pixivision_monitor(): logger.info(f'pixivision_monitor: checking completed, 没有新的article') return - sub = DBSubscription(sub_type=8, sub_id=-1) - # 获取订阅了该直播间的所有群 - sub_group_res = await sub.sub_group_list() - sub_group = sub_group_res.result - # 需通知的群 - notice_group = list(set(all_noitce_groups) & set(sub_group)) + subscription = DBSubscription(sub_type=8, sub_id=-1) # 处理新的aritcle for article in new_article: @@ -96,45 +90,86 @@ async def pixivision_monitor(): tags = list(article['tags']) a_res = await pixivsion_article_parse(aid=aid, tags=tags) if a_res.success(): - if not notice_group: - continue article_data = a_res.result msg = f"新的Pixivision特辑!\n\n" \ f"《{article_data['title']}》\n\n{article_data['description']}\n{article_data['url']}" - for group_id in notice_group: - for _bot in bots: - try: - await _bot.call_api(api='send_group_msg', group_id=group_id, message=msg) - except Exception as e: - logger.warning(f"向群组: {group_id} 发送article简介信息失败, error: {repr(e)}") - continue + + for _bot in bots: + msg_sender = MsgSender(bot=_bot, log_flag='NewPixivisionArticle') + await msg_sender.safe_broadcast_groups_subscription(subscription=subscription, message=msg) + # 处理article中图片内容 tasks = [] for pid in article_data['illusts_list']: - tasks.append(PixivIllust(pid=pid).pic_2_base64()) + tasks.append(PixivIllust(pid=pid).get_file()) p_res = await asyncio.gather(*tasks) image_error = 0 - for image_res in p_res: - if not image_res.success(): - image_error += 1 - continue - else: - img_seg = MessageSegment.image(image_res.result) - # 发送图片 - for group_id in notice_group: + + if ENABLE_NODE_CUSTOM: + node_messages = [] + for image_res in p_res: + if not image_res.success(): + image_error += 1 + continue + # 构造自定义消息节点 + node_messages.append(MessageSegment.image(image_res.result)) + # 发送消息 + for _bot in bots: + msg_sender = MsgSender(bot=_bot, log_flag='NewPixivisionImage') + await msg_sender.safe_broadcast_groups_subscription_node_custom( + subscription=subscription, message_list=node_messages) + else: + for image_res in p_res: + if not image_res.success(): + image_error += 1 + continue + else: + img_seg = MessageSegment.image(image_res.result) + # 发送图片 for _bot in bots: - try: - await _bot.call_api(api='send_group_msg', group_id=group_id, message=img_seg) - # 避免风控控制推送间隔 - await asyncio.sleep(1) - except Exception as e: - logger.warning(f"向群组: {group_id} 发送图片内容失败, error: {repr(e)}") - continue + msg_sender = MsgSender(bot=_bot, log_flag='NewPixivisionImage') + await msg_sender.safe_broadcast_groups_subscription(subscription=subscription, message=img_seg) + logger.info(f"article: {aid} 图片已发送完成, 失败: {image_error}") else: logger.error(f"article: {aid} 信息解析失败, info: {a_res.info}") logger.info(f'pixivision_monitor: checking completed, 已处理新的article: {repr(new_article)}') + +# 用于首次订阅时刷新数据库信息 +async def init_pixivsion_article(): + # 暂停计划任务避免中途检查更新 + scheduler.pause() + try: + # 获取最新一页pixivision的article + new_article = [] + articles_result = await PixivisionArticle.get_illustration_list() + if articles_result.error: + return + for article in articles_result.result: + try: + article = dict(article) + article_tags_id = [] + article_tags_name = [] + for tag in article['tags']: + article_tags_id.append(int(tag['tag_id'])) + article_tags_name.append(str(tag['tag_name'])) + new_article.append({'aid': int(article['id']), 'tags': article_tags_name}) + except Exception: + continue + # 处理新的aritcle + for article in new_article: + aid = int(article['aid']) + tags = list(article['tags']) + await pixivsion_article_parse(aid=aid, tags=tags) + except Exception as e: + logger.info(f'初始化pixivsion article错误, error: {repr(e)}.') + + scheduler.resume() + logger.info(f'初始化pixivsion article完成, 已将作品信息写入数据库.') + + __all__ = [ - 'scheduler' + 'scheduler', + 'init_pixivsion_article' ] diff --git a/omega_miya/plugins/repeater/__init__.py b/omega_miya/plugins/repeater/__init__.py index 42a176d7..7a0041c4 100644 --- a/omega_miya/plugins/repeater/__init__.py +++ b/omega_miya/plugins/repeater/__init__.py @@ -1,55 +1,69 @@ -import re +from typing import Dict from nonebot import on_message from nonebot.typing import T_State +from nonebot.exception import FinishedException from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import GroupMessageEvent from nonebot.adapters.cqhttp.permission import GROUP -from omega_miya.utils.Omega_plugin_utils import has_notice_permission -from .utils import sp_event_check +from omega_miya.utils.Omega_plugin_utils import OmegaRules +from .data_source import REPLY_RULES -last_msg = {} -last_repeat_msg = {} -repeat_count = {} -repeater = on_message(rule=has_notice_permission(), permission=GROUP, priority=100, block=False) +LAST_MSG: Dict[int, str] = {} +LAST_REPEAT_MSG: Dict[int, str] = {} +REPEAT_COUNT: Dict[int, int] = {} + +repeater = on_message(rule=OmegaRules.has_group_command_permission(), permission=GROUP, priority=100, block=False) @repeater.handle() -async def handle_repeater(bot: Bot, event: GroupMessageEvent, state: T_State): +async def handle_ignore_msg(bot: Bot, event: GroupMessageEvent, state: T_State): + msg = event.raw_message + if msg.startswith('/'): + raise FinishedException + elif msg.startswith('!SU'): + raise FinishedException + + +@repeater.handle() +async def handle_auto_reply(bot: Bot, event: GroupMessageEvent, state: T_State): + # 处理回复规则 + msg = event.raw_message group_id = event.group_id + check_res, reply_msg = REPLY_RULES.check_rule(group_id=group_id, message=msg) + if check_res: + await repeater.finish(reply_msg) - global last_msg, last_repeat_msg, repeat_count +@repeater.handle() +async def handle_repeater(bot: Bot, event: GroupMessageEvent, state: T_State): + # 处理复读姬 + global LAST_MSG, LAST_REPEAT_MSG, REPEAT_COUNT + group_id = event.group_id try: - last_msg[group_id] + LAST_MSG[group_id] except KeyError: - last_msg[group_id] = '' + LAST_MSG[group_id] = '' try: - last_repeat_msg[group_id] + LAST_REPEAT_MSG[group_id] except KeyError: - last_repeat_msg[group_id] = '' - - # 特殊消息 - sp_res, sp_msg = await sp_event_check(event=event) - if sp_res: - repeat_count[group_id] = 0 - await repeater.finish(message=sp_msg) + LAST_REPEAT_MSG[group_id] = '' - t_msg = event.message - msg = event.raw_message - - if re.match(r'^/', msg): - return + message = event.get_message() + raw_msg = event.raw_message - if msg != last_msg[group_id] or msg == last_repeat_msg[group_id]: - last_msg[group_id] = msg - repeat_count[group_id] = 0 + # 如果当前消息与上一条消息不同, 或者与上一次复读的消息相同, 则重置复读计数, 并更新上一条消息LAST_MSG + if raw_msg != LAST_MSG[group_id] or raw_msg == LAST_REPEAT_MSG[group_id]: + LAST_MSG[group_id] = raw_msg + REPEAT_COUNT[group_id] = 0 return + # 否则这条消息和上条消息一致, 开始复读计数 else: - repeat_count[group_id] += 1 - last_repeat_msg[group_id] = '' - if repeat_count[group_id] >= 2: - await repeater.send(t_msg) - repeat_count[group_id] = 0 - last_msg[group_id] = '' - last_repeat_msg[group_id] = msg + REPEAT_COUNT[group_id] += 1 + LAST_REPEAT_MSG[group_id] = '' + # 当复读计数等于2时说明已经有连续三条同样的消息了, 此时触发复读, 更新上次服务消息LAST_REPEAT_MSG, 并重置复读计数 + if REPEAT_COUNT[group_id] >= 2: + await repeater.send(message) + REPEAT_COUNT[group_id] = 0 + LAST_MSG[group_id] = '' + LAST_REPEAT_MSG[group_id] = raw_msg diff --git a/omega_miya/plugins/repeater/data_source.py b/omega_miya/plugins/repeater/data_source.py new file mode 100644 index 00000000..42071a2f --- /dev/null +++ b/omega_miya/plugins/repeater/data_source.py @@ -0,0 +1,78 @@ +""" +@Author : Ailitonia +@Date : 2021/06/11 22:39 +@FileName : data_source.py +@Project : nonebot2_miya +@Description : Auto-Reply utils +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import os +import re +import pathlib +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union +from nonebot.adapters.cqhttp.message import Message, MessageSegment + +RESOURCE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'resources')) + + +class ResourceMsg(object): + def __init__(self, resource_name: str): + self.resource_name: str = resource_name + + def img_msg(self) -> MessageSegment: + img_file_path = os.path.abspath(os.path.join(RESOURCE_PATH, self.resource_name)) + file_url = pathlib.Path(img_file_path).as_uri() + return MessageSegment.image(file=file_url) + + def record_msg(self) -> MessageSegment: + record_file_path = os.path.abspath(os.path.join(RESOURCE_PATH, self.resource_name)) + file_url = pathlib.Path(record_file_path).as_uri() + return MessageSegment.record(file=file_url) + + +@dataclass +class Reply: + group_id: List[int] + handle: bool + reply_msg: Union[str, Message, MessageSegment] + + +@dataclass +class ReplyRules: + rules: Dict[str, Reply] + + def check_rule(self, group_id: int, message: str) -> Tuple[bool, Union[str, Message, MessageSegment]]: + for regular, reply in self.rules.items(): + try: + # 使用正则规格匹配消息 + if re.match(regular, message): + # 使用回复群组限制匹配是否回复群组, 空则为无限制 + if not reply.group_id or group_id in reply.group_id: + # 判断回复消息是否是需要处理占位符的信息, 目前暂时只支持单组匹配及占位符填充 + if reply.handle: + reply_msg = reply.reply_msg.format(re.findall(regular, message)[0]) + else: + reply_msg = reply.reply_msg + # 按顺序匹配中立即返回, 忽略后续规则 + return True, reply_msg + except Exception: + continue + return False, '' + + +REPLY_RULES: ReplyRules = ReplyRules(rules={ + r'(.+)好萌好可爱$': Reply(group_id=[], handle=True, reply_msg=r'我也觉得{}好萌好可爱~'), + r'^#测试群友(.+)浓度#?$': Reply(group_id=[], handle=True, reply_msg=r'群友{}浓度已超出测量范围Σ(っ °Д °;)っ'), + r'^对呀对呀$': Reply(group_id=[], handle=False, reply_msg=r'对呀对呀~'), + r'^小母猫$': Reply(group_id=[], handle=False, reply_msg=r'喵喵喵~'), + r'^优质(解答|回答)(\.jpg)?$': Reply(group_id=[], handle=False, reply_msg=ResourceMsg('good_answer.jpg').img_msg()), + r'^[Dd]{2}们都是变态吗[\??]?$': Reply(group_id=[], handle=False, reply_msg=r'你好,是的') +}) + + +__all__ = [ + 'REPLY_RULES' +] diff --git a/omega_miya/plugins/repeater/img_res/good_answer.jpg b/omega_miya/plugins/repeater/resources/good_answer.jpg similarity index 100% rename from omega_miya/plugins/repeater/img_res/good_answer.jpg rename to omega_miya/plugins/repeater/resources/good_answer.jpg diff --git a/omega_miya/plugins/repeater/utils.py b/omega_miya/plugins/repeater/utils.py deleted file mode 100644 index 232c5762..00000000 --- a/omega_miya/plugins/repeater/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import re -import os -from nonebot.adapters import Event -from nonebot.adapters.cqhttp import Message - - -def img_message(img_name: str) -> Message: - img_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'img_res', img_name)) - return Message(f"[CQ:image,file=file:///{img_path}]") - - -sp_msg = { - r'(.+)好萌好可爱$': - {'group_id': [], 'replyMsg': r'我也觉得{}好萌好可爱', 'handle': True}, - r'^#测试群友(.+)浓度#?$': - {'group_id': [], 'replyMsg': r'群友{}浓度已超出测量范围Σ(っ °Д °;)っ', 'handle': True}, - r'^对呀对呀$': - {'group_id': [], 'replyMsg': r'对呀对呀', 'handle': False}, - r'^小母猫': - {'group_id': [], 'replyMsg': r'喵喵喵~', 'handle': False}, - r'^优质(解答|回答)(\.jpg)?$': - {'group_id': [], 'replyMsg': img_message('good_answer.jpg'), 'handle': False}, - r'^[Dd]{2}们都是变态吗[\??]?$': - {'group_id': [], 'replyMsg': r'你好,是的', 'handle': False}, -} - - -async def sp_event_check(event: Event) -> (bool, str): - msg = str(event.get_message()) - group_id = event.dict().get('group_id') - for key in sp_msg.keys(): - if re.match(key, msg): - if group_id in sp_msg.get(key).get('group_id') or not sp_msg.get(key).get('group_id'): - handle = sp_msg.get(key).get('handle') - if handle: - msg = sp_msg.get(key).get('replyMsg').format(re.findall(key, msg)[0]) - return True, msg - else: - msg = sp_msg.get(key).get('replyMsg') - return True, msg - return False, '' diff --git a/omega_miya/plugins/roll/__init__.py b/omega_miya/plugins/roll/__init__.py index 012c602e..d112776d 100644 --- a/omega_miya/plugins/roll/__init__.py +++ b/omega_miya/plugins/roll/__init__.py @@ -17,13 +17,23 @@ **Permission** Command & Lv.10 +or AuthNode + +**AuthNode** +basic **Usage** /roll d /抽奖 <人数>''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + Roll = CommandGroup( 'R0ll', @@ -31,7 +41,8 @@ state=init_permission_state( name='roll', command=True, - level=10), + level=10, + auth_node='basic'), permission=GROUP, priority=10, block=True) @@ -65,72 +76,31 @@ async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_Stat async def handle_roll(bot: Bot, event: GroupMessageEvent, state: T_State): _roll = state['roll'] if re.match(r'^\d+[d]\d+$', _roll): + # d dice_info = _roll.split('d') dice_num = int(dice_info[0]) dice_side = int(dice_info[1]) - # 加入一个趣味的机制 - if random.randint(1, 100) == 99: - await roll.finish(f'【彩蛋】骰子之神似乎不看好你, 你掷出的骰子全部消失了') - if dice_num > 1000 or dice_side > 1000: - await roll.finish(f'【错误】谁没事干扔那么多骰子啊(╯°□°)╯︵ ┻━┻') - if dice_num <= 0 or dice_side <= 0: - await roll.finish(f'【错误】你掷出了不存在的骰子, 只有上帝知道结果是多少') - dice_result = 0 - for i in range(dice_num): - this_dice_result = random.choice(range(dice_side)) + 1 - dice_result += this_dice_result - await roll.finish(f'你掷出了{dice_num}个{dice_side}面骰子, 点数为【{dice_result}】') + elif re.match(r'^[d]\d+$', _roll): + # d + dice_num = 1 + dice_side = int(_roll[1:]) + elif re.match(r'^\d+$', _roll): + # Any number + dice_num = 1 + dice_side = int(_roll) else: await roll.finish(f'格式不对呢, 请重新输入: /roll d:') - - -lottery = Roll.command('lottery', aliases={'抽奖'}) - - -# 修改默认参数处理 -@lottery.args_parser -async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): - args = str(event.get_plaintext()).strip().lower().split() - if not args: - await lottery.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') - state[state["_current_key"]] = args[0] - if state[state["_current_key"]] == '取消': - await lottery.finish('操作已取消') - - -@lottery.handle() -async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): - args = str(event.get_plaintext()).strip().lower().split() - if not args: - pass - elif args and len(args) == 1: - state['lottery'] = args[0] - else: - await lottery.finish('参数错误QAQ') - - -@lottery.got('lottery', prompt='请输入抽奖人数') -async def handle_lottery(bot: Bot, event: GroupMessageEvent, state: T_State): - _lottery = state['lottery'] - if re.match(r'^\d+$', _lottery): - people_num = int(_lottery) - - group_member_list = await bot.call_api(api='get_group_member_list', group_id=event.group_id) - group_user_name_list = [] - - for user_info in group_member_list: - # 用户信息 - user_nickname = user_info['nickname'] - user_group_nickmane = user_info['card'] - if not user_group_nickmane: - user_group_nickmane = user_nickname - group_user_name_list.append(user_group_nickmane) - - if people_num > len(group_user_name_list): - await lottery.finish(f'【错误】抽奖人数大于群成员人数了QAQ') - - lottery_result = random.sample(group_user_name_list, k=people_num) - msg = '【' + str.join('】\n【', lottery_result) + '】' - await lottery.finish(f"抽奖人数: 【{people_num}】\n以下是中奖名单:\n{msg}") - else: - await lottery.finish(f'格式不对呢, 人数应该是数字') + return + + # 加入一个趣味的机制 + if random.randint(1, 100) == 99: + await roll.finish(f'【彩蛋】骰子之神似乎不看好你, 你掷出的骰子全部消失了') + if dice_num > 1024 or dice_side > 1024: + await roll.finish(f'【错误】谁没事干扔那么多骰子啊(╯°□°)╯︵ ┻━┻') + if dice_num <= 0 or dice_side <= 0: + await roll.finish(f'【错误】你掷出了不存在的骰子, 只有上帝知道结果是多少') + dice_result = 0 + for i in range(dice_num): + this_dice_result = random.choice(range(dice_side)) + 1 + dice_result += this_dice_result + await roll.finish(f'你掷出了{dice_num}个{dice_side}面骰子, 点数为【{dice_result}】') diff --git a/omega_miya/plugins/schedule_message/__init__.py b/omega_miya/plugins/schedule_message/__init__.py new file mode 100644 index 00000000..0da97877 --- /dev/null +++ b/omega_miya/plugins/schedule_message/__init__.py @@ -0,0 +1,388 @@ +""" +@Author : Ailitonia +@Date : 2021/06/22 20:38 +@FileName : schedule_message.py +@Project : nonebot2_miya +@Description : 定时消息插件 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import nonebot +import re +from datetime import datetime +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from nonebot import MatcherGroup, export, logger, require +from nonebot.typing import T_State +from nonebot.permission import SUPERUSER +from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.message import Message +from nonebot.adapters.cqhttp.event import GroupMessageEvent +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, Result +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state + + +# Custom plugin usage text +__plugin_name__ = '定时消息' +__plugin_usage__ = r'''【定时消息】 +设置群组定时通知消息 +仅限群聊使用 + +**Permission** +Command & Lv.10 +or AuthNode + +**AuthNode** +basic + +**Usage** +**GroupAdmin and SuperUser Only** +/设置定时消息 +/查看定时消息 +/删除定时消息''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + +driver = nonebot.get_driver() +scheduler: AsyncIOScheduler = require("nonebot_plugin_apscheduler").scheduler + +# 注册事件响应器 +ScheduleMsg = MatcherGroup( + type='message', + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='schedule_message', + command=True, + level=10, + auth_node='basic'), + permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, + priority=10, + block=True) + + +set_schedule_message = ScheduleMsg.on_command('设置定时消息', aliases={'添加定时消息'}) +list_schedule_message = ScheduleMsg.on_command('查看定时消息') +del_schedule_message = ScheduleMsg.on_command('删除定时消息', aliases={'移除定时消息'}) + + +# 设置定时消息部分 +# 修改默认参数处理 +@set_schedule_message.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_message()).strip() + if not args: + await set_schedule_message.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args + if state[state["_current_key"]] == '取消': + await set_schedule_message.finish('操作已取消') + + +@set_schedule_message.got('mode', prompt='请发送设置定时消息的模式:\n【 cron / interval 】\n\n模式说明:\n' + 'cron(闹钟) - 每天某个具体时间发送消息\n' + 'interval(定时器) - 每间隔一定时间发送消息') +async def handle_mode(bot: Bot, event: GroupMessageEvent, state: T_State): + mode = state['mode'] + if mode not in ['cron', 'interval']: + await set_schedule_message.finish('您发送的不是有效的模式QAQ') + if mode == 'interval': + state['repeat'] = 'all' + + +@set_schedule_message.got('name', prompt='请发送为当前定时任务设置的名称:') +async def handle_time(bot: Bot, event: GroupMessageEvent, state: T_State): + _name = state['name'] + if len(_name) > 100: + await set_schedule_message.finish('设置的名称过长QAQ') + + +@set_schedule_message.got('time', prompt='请发送你要设置定时时间, 时间格式为24小时制四位数字:\n\n设置说明:\n' + '若模式为cron(闹钟), 则“1830”代表每天下午六点半发送定时消息\n' + '若模式为interval(定时器), 则“0025”代表每隔25分钟发送定时消息') +async def handle_time(bot: Bot, event: GroupMessageEvent, state: T_State): + time = state['time'] + mode = state['mode'] + try: + _time = datetime.strptime(time, '%H%M') + _hour = _time.hour + _minute = _time.minute + except ValueError: + await set_schedule_message.finish('输入的时间格式错误QAQ, 应该为24小时制四位数字') + return + if mode == 'interval' and _hour == 0 and _minute == 0: + await set_schedule_message.finish('输入的时间格式错误QAQ, interval模式不允许时间为0000') + return + state['hour'] = _hour + state['minute'] = _minute + + +@set_schedule_message.got('repeat', prompt='是否按星期重复?\n\n若只想在一周的某一天执行请以下日期中选择:\n' + '【mon/tue/wed/thu/fri/sat/sun】\n\n' + '若想每一天都执行请输入:\n【all】') +async def handle_time(bot: Bot, event: GroupMessageEvent, state: T_State): + repeat = state['repeat'] + if repeat not in ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun', 'all']: + await set_schedule_message.finish('输入的日期格式错误QAQ, 请在【mon/tue/wed/thu/fri/sat/sun/all】中选择输入') + + +@set_schedule_message.got('message', prompt='请发送你要设置的消息内容:') +async def handle_message(bot: Bot, event: GroupMessageEvent, state: T_State): + message = state['message'] + name = state['name'] + mode = state['mode'] + hour = state['hour'] + minute = state['minute'] + repeat = state['repeat'] + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + try: + await add_scheduler( + group=group, schedule_name=name, mode=mode, hour=hour, minute=minute, repeat=repeat, message=message) + except Exception as e: + logger.error(f'为群组: {group_id} 设置群组定时消息失败任务, 添加计划任务时发生错误: {repr(e)}') + await set_schedule_message.finish(f'为本群组设定群组定时消息失败了QAQ, 请稍后再试或联系管理员处理') + + msg_set_result = await add_db_group_schedule_message( + group=group, schedule_name=name, mode=mode, hour=hour, minute=minute, repeat=repeat, message=message) + + if msg_set_result.success(): + logger.info(f'已为群组: {group_id} 设置群组定时消息: {name}{mode}/{hour}:{minute}') + await set_schedule_message.finish(f'已为本群组设定了群组定时消息:\n{name}/{mode}/{repeat}:{hour}:{minute}') + else: + logger.error(f'为群组: {group_id} 设置群组定时消息失败, error info: {msg_set_result.info}') + await set_schedule_message.finish(f'为本群组设定了群组定时消息失败了QAQ, 请稍后再试或联系管理员处理') + + +# 查看定时消息部分 +@list_schedule_message.handle() +async def handle(bot: Bot, event: GroupMessageEvent, state: T_State): + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=event.group_id, self_bot=self_bot) + schedule_result = await list_db_group_schedule_message(group=group) + if schedule_result.error: + logger.error(f'Get group {event.group_id} message schedule list failed: {schedule_result.info}') + await list_schedule_message.finish(f'获取群定时消息失败了QAQ, 请稍后再试或联系管理员处理') + msg = f'本群已设置的定时消息任务:\n{"="*12}' + for _name, _mode, _time, _message in schedule_result.result: + _name = re.sub(r'^ScheduleMsg_', '', str(_name)) + msg += f'\n【{_name}】 - {_mode}({_time})' + await list_schedule_message.finish(msg) + + +# 删除定时消息部分 +# 修改默认参数处理 +@del_schedule_message.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_message()).strip() + if not args: + await del_schedule_message.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args + if state[state["_current_key"]] == '取消': + await del_schedule_message.finish('操作已取消') + + +@del_schedule_message.handle() +async def handle_jobs(bot: Bot, event: GroupMessageEvent, state: T_State): + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=event.group_id, self_bot=self_bot) + schedule_result = await list_db_group_schedule_message(group=group) + if schedule_result.error: + logger.error(f'Get group {event.group_id} message schedule list failed: {schedule_result.info}') + await list_schedule_message.finish(f'获取群定时消息列表失败了QAQ, 请稍后再试或联系管理员处理') + msg = f'本群已设置的定时消息任务有:\n{"="*12}' + for _name, _mode, _time, _message in schedule_result.result: + _name = re.sub(r'^ScheduleMsg_', '', str(_name)) + msg += f'\n【{_name}】 - {_mode}({_time})' + await list_schedule_message.send(msg) + + +@del_schedule_message.got('name', prompt='请发送将要移除的定时任务的名称:') +async def handle_remove(bot: Bot, event: GroupMessageEvent, state: T_State): + name = state['name'] + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + try: + await remove_scheduler(group=group, schedule_name=name) + except Exception as e: + logger.error(f'移除群组: {group_id} 定时消息失败, 移除计划任务时发生错误: {repr(e)}') + await del_schedule_message.finish(f'移除群组定时消息失败了QAQ, 请稍后再试或联系管理员处理') + + msg_del_result = await del_db_group_schedule_message(group=group, schedule_name=name) + + if msg_del_result.success(): + logger.info(f'已移除群组: {group_id} 群组定时消息: {name}') + await del_schedule_message.finish(f'已移除群组定时消息: {name}') + else: + logger.error(f'移除群组: {group_id} 群组定时消息失败, error info: {msg_del_result.info}') + await del_schedule_message.finish(f'移除群组定时消息失败了QAQ, 请稍后再试或联系管理员处理') + + +async def add_db_group_schedule_message( + group: DBBotGroup, + schedule_name: str, + mode: str, + hour: int, + minute: int, + repeat: str, + message: str) -> Result.IntResult: + # 初始化计划任务设置ID + _schedule_setting_id = f'ScheduleMsg_{schedule_name}' + schedule_set_result = await group.setting_set(setting_name=_schedule_setting_id, main_config=mode, + secondary_config=f'{repeat}:{hour}:{minute}', + extra_config=message, setting_info='群组定时消息') + return schedule_set_result + + +async def list_db_group_schedule_message(group: DBBotGroup) -> Result.ListResult: + exist_setting = await group.setting_list() + if exist_setting.error: + return Result.ListResult(error=True, info=f'Get config wrong: {exist_setting.info}', result=[]) + else: + result = [x for x in exist_setting.result if str(x[0]).startswith('ScheduleMsg_')] + return Result.ListResult(error=False, info=f'Success', result=result) + + +async def del_db_group_schedule_message(group: DBBotGroup, schedule_name: str) -> Result.IntResult: + _schedule_setting_id = f'ScheduleMsg_{schedule_name}' + result = await group.setting_del(setting_name=_schedule_setting_id) + return result + + +async def add_scheduler( + group: DBBotGroup, + schedule_name: str, + mode: str, + hour: int, + minute: int, + repeat: str, + message: str): + global scheduler + _schedule_setting_id = f'ScheduleMsg_{group.self_bot.self_qq}_{schedule_name}' + self_bot: Bot = nonebot.get_bots().get(str(group.self_bot.self_qq), None) + if not self_bot: + raise ValueError('Can not get Bot') + + async def _scheduler_handle(): + await self_bot.send_group_msg(group_id=group.group_id, message=Message(f'【定时消息】\n{"="*12}\n{message}')) + + if mode == 'cron': + if repeat == 'all': + scheduler.add_job( + _scheduler_handle, + 'cron', + hour=hour, + minute=minute, + id=_schedule_setting_id, + coalesce=True, + misfire_grace_time=10 + ) + else: + scheduler.add_job( + _scheduler_handle, + 'cron', + day_of_week=repeat, + hour=hour, + minute=minute, + id=_schedule_setting_id, + coalesce=True, + misfire_grace_time=10 + ) + elif mode == 'interval': + if hour == 0 and minute != 0: + scheduler.add_job( + _scheduler_handle, + 'interval', + minutes=minute, + id=_schedule_setting_id, + coalesce=True, + misfire_grace_time=10 + ) + elif minute == 0: + scheduler.add_job( + _scheduler_handle, + 'interval', + hours=hour, + id=_schedule_setting_id, + coalesce=True, + misfire_grace_time=10 + ) + else: + scheduler.add_job( + _scheduler_handle, + 'interval', + hours=hour, + minutes=minute, + id=_schedule_setting_id, + coalesce=True, + misfire_grace_time=10 + ) + else: + raise ValueError(f'Unknown mode {mode}') + + +async def remove_scheduler(group: DBBotGroup, schedule_name: str): + global scheduler + _schedule_setting_id = f'ScheduleMsg_{group.self_bot.self_qq}_{schedule_name}' + scheduler.remove_job(_schedule_setting_id) + + +# Bot 连接时初始化其消息任务 +@driver.on_bot_connect +async def init_bot_message_schedule(bot: Bot): + self_bot = DBBot(self_qq=int(bot.self_id)) + group_list_result = await DBBotGroup.list_exist_bot_groups(self_bot=self_bot) + if group_list_result.error: + logger.error(f'Init bot message schedule failed, get bot group list failed: {group_list_result.info}') + for group in group_list_result.result: + _bot_group = DBBotGroup(group_id=group, self_bot=self_bot) + schedule_result = await list_db_group_schedule_message(group=_bot_group) + if schedule_result.error: + logger.error(f'Error occurred in init bot message schedule, ' + f'get group {_bot_group.group_id} message schedule list failed: {schedule_result.info}') + continue + for _name, _mode, _time, _message in schedule_result.result: + _name = re.sub(r'^ScheduleMsg_', '', str(_name)) + _repeat, _hour, _minute = [x for x in str(_time).split(':', maxsplit=3)] + _hour = int(_hour) + _minute = int(_minute) + try: + await add_scheduler(group=_bot_group, schedule_name=_name, + mode=_mode, hour=_hour, minute=_minute, repeat=_repeat, message=_message) + except Exception as e: + logger.error(f'Init bot message schedule failed, ' + f'为群组: {_bot_group.group_id} 添加群组定时消息任务失败, 添加计划任务时发生错误: {repr(e)}') + continue + + +# Bot 断开连接时移除其消息任务 +@driver.on_bot_disconnect +async def remove_bot_message_schedule(bot: Bot): + self_bot = DBBot(self_qq=int(bot.self_id)) + group_list_result = await DBBotGroup.list_exist_bot_groups(self_bot=self_bot) + if group_list_result.error: + logger.error(f'Remove bot message schedule failed, get bot group list failed: {group_list_result.info}') + for group in group_list_result.result: + _bot_group = DBBotGroup(group_id=group, self_bot=self_bot) + schedule_result = await list_db_group_schedule_message(group=_bot_group) + if schedule_result.error: + logger.error(f'Error occurred in remove bot message schedule, ' + f'get group {_bot_group.group_id} message schedule list failed: {schedule_result.info}') + continue + for _name, _mode, _time, _message in schedule_result.result: + _repeat, _hour, _minute = [x for x in str(_time).split(':', maxsplit=3)] + _hour = int(_hour) + _minute = int(_minute) + try: + await remove_scheduler(group=_bot_group, schedule_name=_name) + except Exception as e: + logger.error(f'Remove bot message schedule failed, ' + f'移除群组: {_bot_group.group_id} 定时消息任务失败, 移除计划任务时发生错误: {repr(e)}') + continue diff --git a/omega_miya/plugins/search_anime/__init__.py b/omega_miya/plugins/search_anime/__init__.py index 59a9d318..96c83289 100644 --- a/omega_miya/plugins/search_anime/__init__.py +++ b/omega_miya/plugins/search_anime/__init__.py @@ -1,12 +1,11 @@ -import re from nonebot import on_command, export, logger from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP, PRIVATE_FRIEND from nonebot.adapters.cqhttp import MessageSegment, Message -from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state -from .utils import get_identify_result, pic_2_base64 +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PicEncoder +from .utils import get_identify_result # Custom plugin usage text @@ -56,22 +55,38 @@ async def parse(bot: Bot, event: MessageEvent, state: T_State): args = str(event.get_message()).strip().split() if not args: await search_anime.reject('你似乎没有发送有效的消息呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] if state[state["_current_key"]] == '取消': await search_anime.finish('操作已取消') + for msg_seg in event.message: + if msg_seg.type == 'image': + state[state["_current_key"]] = msg_seg.data.get('url') + return + @search_anime.handle() async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): - # 响应引用回复图片 + # 提取图片链接, 默认只取消息中的第一张图 + img_url = None if event.reply: - img_url = str(event.reply.message).strip() - if re.match(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=.+?])$', img_url): - state['image_url'] = img_url + for msg_seg in event.reply.message: + if msg_seg.type == 'image': + img_url = msg_seg.data.get('url') + break + else: + for msg_seg in event.message: + if msg_seg.type == 'image': + img_url = msg_seg.data.get('url') + break + if img_url: + state['image_url'] = img_url + return args = str(event.get_plaintext()).strip().lower().split() if args: - await search_anime.finish('该命令不支持参数QAQ') + await search_anime.finish('你发送的好像不是图片呢QAQ') @search_anime.got('image_url', prompt='请发送你想要识别的番剧截图:') @@ -82,53 +97,42 @@ async def handle_draw(bot: Bot, event: MessageEvent, state: T_State): group_id = 'Private event' image_url = state['image_url'] - if not re.match(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=.+?])$', image_url): - await search_anime.reject('你发送的似乎不是图片呢, 请重新发送, 取消命令请发送【取消】:') - - # 提取图片url - image_url = re.sub(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=)', '', image_url) - image_url = re.sub(r'(])$', '', image_url) await search_anime.send('获取识别结果中, 请稍后~') - res = await get_identify_result(img_url=image_url) if not res.success(): - logger.info(f"{group_id} / {event.user_id} search_anime failed: {res.info}") + logger.warning(f"{group_id} / {event.user_id} search_anime failed: {res.info}") await search_anime.finish('发生了意外的错误QAQ, 请稍后再试') if not res.result: - logger.info(f"{group_id} / {event.user_id} 使用了search_anime, 但没有找到相似的番剧") + logger.warning(f"{group_id} / {event.user_id} 使用了search_anime, 但没有找到相似的番剧") await search_anime.finish('没有找到与截图相似度足够高的番剧QAQ') for item in res.result: try: - raw_at = item.get('raw_at') - at = item.get('at') - anilist_id = item.get('anilist_id') - anime = item.get('anime') - episode = item.get('episode') - tokenthumb = item.get('tokenthumb') filename = item.get('filename') + episode = item.get('episode') + from_ = item.get('from') + to = item.get('to') similarity = item.get('similarity') + image = item.get('image') title_native = item.get('title_native') title_chinese = item.get('title_chinese') is_adult = item.get('is_adult') - thumb_img_url = f'https://trace.moe/thumbnail.php?' \ - f'anilist_id={anilist_id}&file={filename}&t={raw_at}&token={tokenthumb}' + img_result = await PicEncoder(pic_url=image).get_file(folder_flag='search_anime') - img_b64 = await pic_2_base64(thumb_img_url) - if not img_b64.success(): - msg = f"识别结果: {anime}\n\n原始名称:【{title_native}】\n中文名称:【{title_chinese}】\n" \ - f"相似度: {int(similarity)}\n\n来源文件: {filename}\nEpisode: 【{episode}】\n" \ - f"截图时间位置: {at}\n绅士: {is_adult}" + if img_result.error: + msg = f"识别结果:\n\n原始名称:【{title_native}】\n中文名称:【{title_chinese}】\n" \ + f"相似度: {int(similarity*100)}\n\n来源文件: {filename}\n集数: 【{episode}】\n" \ + f"预览图时间位置: {from_} - {to}\n绅士: {is_adult}" await search_anime.send(msg) else: - img_seg = MessageSegment.image(img_b64.result) - msg = f"识别结果: {anime}\n\n原始名称:【{title_native}】\n中文名称:【{title_chinese}】\n" \ - f"相似度: {int(similarity)}\n\n来源文件: {filename}\nEpisode: 【{episode}】\n" \ - f"截图时间位置: {at}\n绅士: {is_adult}\n{img_seg}" - await search_anime.send(Message(msg)) + img_seg = MessageSegment.image(img_result.result) + msg = f"识别结果:\n\n原始名称:【{title_native}】\n中文名称:【{title_chinese}】\n" \ + f"相似度: {int(similarity*100)}\n\n来源文件: {filename}\n集数: 【{episode}】\n" \ + f"预览图时间位置: {from_} - {to}\n绅士: {is_adult}" + await search_anime.send(Message(msg).append(img_seg)) except Exception as e: logger.error(f"{group_id} / {event.user_id} 使用命令search_anime时发生了错误: {repr(e)}") continue diff --git a/omega_miya/plugins/search_anime/utils.py b/omega_miya/plugins/search_anime/utils.py index b0d6d23b..437a30e1 100644 --- a/omega_miya/plugins/search_anime/utils.py +++ b/omega_miya/plugins/search_anime/utils.py @@ -1,32 +1,39 @@ -import datetime from nonebot import logger -from omega_miya.utils.Omega_plugin_utils import HttpFetcher, PicEncoder +from omega_miya.utils.Omega_plugin_utils import HttpFetcher from omega_miya.utils.Omega_Base import Result -API_URL = 'https://trace.moe/api/search' +API_URL = 'https://api.trace.moe/search' +# ANILIST_API_URL = 'https://graphql.anilist.co' # Anilist API +ANILIST_API_URL = 'https://trace.moe/anilist/' # 中文 Anilist API +ANILIST_API_QUERY = ''' +query ($id: Int) { # Define which variables will be used in the query (id) + Media (id: $id, type: ANIME) { # Insert our variables into the query arguments (id) (type: ANIME is hard-coded in the query) + id # you must query the id field for it to search the translated database + title { + native # do not query chinese here, the official Anilist API doesn't recognize + romaji + english + } + isAdult + synonyms # chinese titles will always be merged into this array + } +} +''' -HEADERS = {'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' - 'Chrome/89.0.4389.114 Safari/537.36'} - - -# 图片转base64 -async def pic_2_base64(url: str) -> Result.TextResult: - fetcher = HttpFetcher(timeout=10, flag='search_anime_get_image', headers=HEADERS) - bytes_result = await fetcher.get_bytes(url=url) - if bytes_result.error: - return Result.TextResult(error=True, info='Image download failed', result='') - - encode_result = PicEncoder.bytes_to_b64(image=bytes_result.result) - - if encode_result.success(): - return Result.TextResult(error=False, info='Success', result=encode_result.result) - else: - return Result.TextResult(error=True, info=encode_result.info, result='') +HEADERS = {'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,' + 'application/signed-exchange;v=b3;q=0.9', + 'accept-encoding': 'gzip, deflate', + 'accept-language': 'zh-CN,zh;q=0.9', + 'cache-control': 'max-age=0', + 'dnt': '1', + 'upgrade-insecure-requests': '1', + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' + 'Chrome/90.0.4430.212 Safari/537.36'} # 获取识别结果 -async def get_identify_result(img_url: str) -> Result.ListResult: +async def get_identify_result(img_url: str, *, sensitivity: float = 0.8) -> Result.ListResult: fetcher = HttpFetcher(timeout=10, flag='search_anime', headers=HEADERS) payload = {'url': img_url} @@ -34,27 +41,34 @@ async def get_identify_result(img_url: str) -> Result.ListResult: if not result_json.success(): return Result.ListResult(error=True, info=result_json.info, result=[]) - _res = result_json.result - if not _res.get('docs'): - return Result.ListResult(error=True, info='no result found', result=[]) - _result = [] - for item in _res.get('docs'): + for item in result_json.result.get('result'): try: - if item.get('similarity') < 0.85: + if item.get('similarity') < sensitivity: continue + anilist = item.get('anilist') + # 获取番剧信息 + payload = {'query': ANILIST_API_QUERY, 'variables': {'id': anilist}} + anilist_result = await fetcher.post_json(url=ANILIST_API_URL, json=payload) + if anilist_result.error: + raise Exception(anilist_result.info) + + title_native = anilist_result.result['data']['Media']['title']['native'] + title_chinese = anilist_result.result['data']['Media']['title']['chinese'] + is_adult = anilist_result.result['data']['Media']['isAdult'] + _result.append({ - 'raw_at': item.get('at'), - 'at': str(datetime.timedelta(seconds=item.get('at'))), - 'anilist_id': item.get('anilist_id'), - 'anime': item.get('anime'), - 'episode': item.get('episode'), - 'tokenthumb': item.get('tokenthumb'), + 'anilist': anilist, 'filename': item.get('filename'), + 'episode': item.get('episode'), + 'from': item.get('from'), + 'to': item.get('to'), 'similarity': item.get('similarity'), - 'title_native': item.get('title_native'), - 'title_chinese': item.get('title_chinese'), - 'is_adult': item.get('is_adult'), + 'video': item.get('video'), + 'image': item.get('image'), + 'title_native': title_native, + 'title_chinese': title_chinese, + 'is_adult': is_adult, }) except Exception as e: logger.warning(f'result parse failed: {repr(e)}, raw_json: {item}') diff --git a/omega_miya/plugins/search_image/__init__.py b/omega_miya/plugins/search_image/__init__.py index 0308ee1e..f07e6fcd 100644 --- a/omega_miya/plugins/search_image/__init__.py +++ b/omega_miya/plugins/search_image/__init__.py @@ -1,12 +1,22 @@ -import re -from nonebot import on_command, export, logger +import random +import asyncio +from nonebot import on_command, export, logger, get_driver from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP, PRIVATE_FRIEND from nonebot.adapters.cqhttp import MessageSegment, Message -from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state -from .utils import pic_2_base64, get_saucenao_identify_result, get_ascii2d_identify_result +from omega_miya.utils.Omega_Base import DBBot +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PicEncoder, PermissionChecker +from omega_miya.utils.pixiv_utils import PixivIllust +from .utils import SEARCH_ENGINE, HEADERS +from .config import Config + +__global_config = get_driver().config +plugin_config = Config(**__global_config.dict()) +ENABLE_SAUCENAO = plugin_config.enable_saucenao +ENABLE_IQDB = plugin_config.enable_iqdb +ENABLE_ASCII2D = plugin_config.enable_ascii2d # Custom plugin usage text __plugin_name__ = '识图' @@ -23,17 +33,21 @@ basic **Usage** -/识图''' +/识图 + +**Hidden Command** +/再来点''' # 声明本插件可配置的权限节点 __plugin_auth_node__ = [ - 'basic' + 'basic', + 'recommend_image', + 'allow_recommend_r18' ] # Init plugin export init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) - # 注册事件响应器 search_image = on_command( '识图', @@ -55,73 +69,125 @@ async def parse(bot: Bot, event: MessageEvent, state: T_State): args = str(event.get_message()).strip().split() if not args: await search_image.reject('你似乎没有发送有效的消息呢QAQ, 请重新发送:') + + if state["_current_key"] == 'using_engine': + if args[0] == '是': + return + else: + await search_image.finish('操作已取消') + state[state["_current_key"]] = args[0] if state[state["_current_key"]] == '取消': await search_image.finish('操作已取消') + for msg_seg in event.message: + if msg_seg.type == 'image': + state[state["_current_key"]] = msg_seg.data.get('url') + return + @search_image.handle() async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + # 识图引擎开关 + usable_engine = [] + if ENABLE_SAUCENAO: + usable_engine.append('saucenao') + if ENABLE_IQDB: + usable_engine.append('iqdb') + if ENABLE_ASCII2D: + usable_engine.append('ascii2d') + + state['using_engine'] = usable_engine.pop(0) if usable_engine else None + state['usable_engine'] = usable_engine + + # 提取图片链接, 默认只取消息中的第一张图 + img_url = None if event.reply: - img_url = str(event.reply.message).strip() - if re.match(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=.+?])$', img_url): - state['image_url'] = img_url + for msg_seg in event.reply.message: + if msg_seg.type == 'image': + img_url = msg_seg.data.get('url') + break + else: + for msg_seg in event.message: + if msg_seg.type == 'image': + img_url = msg_seg.data.get('url') + break + if img_url: + state['image_url'] = img_url + return args = str(event.get_plaintext()).strip().lower().split() if args: - await search_image.finish('该命令不支持参数QAQ') + await search_image.finish('你发送的好像不是图片呢QAQ') @search_image.got('image_url', prompt='请发送你想要识别的图片:') -async def handle_draw(bot: Bot, event: MessageEvent, state: T_State): +async def handle_got_image(bot: Bot, event: MessageEvent, state: T_State): + image_url = state['image_url'] + if not str(image_url).startswith('http'): + await search_image.finish('错误QAQ,你发送的不是有效的图片') + await search_image.send('获取识别结果中, 请稍后~') + + +@search_image.got('using_engine', prompt='使用识图引擎识图:') +async def handle_saucenao(bot: Bot, event: MessageEvent, state: T_State): + image_url = state['image_url'] + using_engine = state['using_engine'] + usable_engine = list(state['usable_engine']) + + # 获取识图结果 + search_engine = SEARCH_ENGINE.get(using_engine, None) + if using_engine and search_engine: + identify_result = await search_engine(image_url) + if identify_result.success() and identify_result.result: + # 有结果了, 继续执行接下来的结果解析handler + pass + else: + # 没有结果 + if identify_result.error: + logger.warning(f'{using_engine}引擎获取识别结果失败: {identify_result.info}') + if usable_engine: + # 还有可用的识图引擎 + next_using_engine = usable_engine.pop(0) + msg = f'{using_engine}引擎没有找到相似度足够高的图片,是否继续使用{next_using_engine}引擎识别图片?\n\n【是/否】' + state['using_engine'] = next_using_engine + state['usable_engine'] = usable_engine + await search_image.reject(msg) + else: + # 没有可用的识图引擎了 + logger.info(f'{event.user_id} 使用了searchimage所有的识图引擎, 但没有找到相似的图片') + await search_image.finish('没有找到相似度足够高的图片QAQ') + else: + logger.error(f'获取识图引擎异常, using_engine: {using_engine}') + await search_image.finish('发生了意外的错误QAQ, 请稍后再试或联系管理员') + return + + state['identify_result'] = identify_result.result + + +@search_image.handle() +async def handle_result(bot: Bot, event: MessageEvent, state: T_State): if isinstance(event, GroupMessageEvent): group_id = event.group_id else: group_id = 'Private event' - image_url = state['image_url'] - if not re.match(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=.+?])$', image_url): - await search_image.reject('你发送的似乎不是图片呢, 请重新发送, 取消命令请发送【取消】:') - - # 提取图片url - image_url = re.sub(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=)', '', image_url) - image_url = re.sub(r'(])$', '', image_url) - + identify_result = state['identify_result'] try: - has_error = False - await search_image.send('获取识别结果中, 请稍后~') - identify_result = [] - identify_saucenao_result = await get_saucenao_identify_result(url=image_url) - if identify_saucenao_result.success(): - identify_result.extend(identify_saucenao_result.result) - else: - has_error = True - - # saucenao 没有结果时再使用 ascii2d 进行搜索 - if not identify_result: - identify_ascii2d_result = await get_ascii2d_identify_result(url=image_url) - # 合并搜索结果 - if identify_ascii2d_result.success(): - identify_result.extend(identify_ascii2d_result.result) - else: - has_error = True if identify_result: for item in identify_result: try: - if type(item['ext_urls']) == list: - ext_urls = '' - for urls in item['ext_urls']: - ext_urls += f'{urls}\n' - ext_urls = ext_urls.strip() + if isinstance(item['ext_urls'], list): + ext_urls = '\n'.join(item['ext_urls']) else: - ext_urls = item['ext_urls'] - ext_urls = ext_urls.strip() - img_b64 = await pic_2_base64(item['thumbnail']) - if not img_b64.success(): + ext_urls = item['ext_urls'].strip() + img_result = await PicEncoder( + pic_url=item['thumbnail'], headers=HEADERS).get_file(folder_flag='search_image') + if img_result.error: msg = f"识别结果: {item['index_name']}\n\n相似度: {item['similarity']}\n资源链接: {ext_urls}" await search_image.send(msg) else: - img_seg = MessageSegment.image(img_b64.result) + img_seg = MessageSegment.image(img_result.result) msg = f"识别结果: {item['index_name']}\n\n相似度: {item['similarity']}\n资源链接: {ext_urls}\n{img_seg}" await search_image.send(Message(msg)) except Exception as e: @@ -129,10 +195,6 @@ async def handle_draw(bot: Bot, event: MessageEvent, state: T_State): continue logger.info(f"{group_id} / {event.user_id} 使用searchimage成功搜索了一张图片") return - elif not identify_result and has_error: - await search_image.send('识图过程中获取信息失败QAQ, 请重试一下吧') - logger.info(f"{group_id} / {event.user_id} 使用了searchimage, 但在识图过程中获取信息失败") - return else: await search_image.send('没有找到相似度足够高的图片QAQ') logger.info(f"{group_id} / {event.user_id} 使用了searchimage, 但没有找到相似的图片") @@ -141,3 +203,125 @@ async def handle_draw(bot: Bot, event: MessageEvent, state: T_State): await search_image.send('识图失败, 发生了意外的错误QAQ, 请稍后重试') logger.error(f"{group_id} / {event.user_id} 使用命令searchimage时发生了错误: {repr(e)}") return + + +# 注册事件响应器 +recommend_image = on_command( # 使用 pixiv api 的相关作品推荐功能查找相似作品 + '再来点', + aliases={'多来点', '相似作品', '类似作品'}, + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='search_image_recommend_image', + command=True, + auth_node='recommend_image'), + permission=GROUP | PRIVATE_FRIEND, + priority=20, + block=True) + + +@recommend_image.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + # 从回复消息中捕获待匹配的图片信息 + # 只返回匹配到的第一个符合要求的链接或图片 + if event.reply: + # 首先筛查链接 + for msg_seg in event.reply.message: + if msg_seg.type == 'text': + text = msg_seg.data.get('text') + if pid := PixivIllust.parse_pid_from_url(text=text): + state['pid'] = pid + logger.debug(f"Recommend image | 已从消息段文本匹配到 pixiv url, pid: {pid}") + return + + # 若消息被分片可能导致链接被拆分 + raw_text = event.reply.dict().get('raw_message') + if pid := PixivIllust.parse_pid_from_url(text=raw_text): + state['pid'] = pid + logger.debug(f"Recommend image | 已从消息 raw 文本匹配到 pixiv url, pid: {pid}") + return + + # 没有发现则开始对图片进行识别, 为保证准确性只使用 saucenao api + for msg_seg in event.reply.message: + if msg_seg.type == 'image': + img_url = msg_seg.data.get('url') + saucenao_search_engine = SEARCH_ENGINE.get('saucenao') + identify_result = await saucenao_search_engine(img_url) + # 从识别结果中匹配图片 + for url_list in [x.get('ext_urls') for x in identify_result.result]: + for url in url_list: + if pid := PixivIllust.parse_pid_from_url(text=url): + state['pid'] = pid + logger.debug(f"Recommend image | 已从识别图片匹配到 pixiv url, pid: {pid}") + return + else: + logger.debug(f'Recommend image | 命令没有引用消息, 操作已取消') + await recommend_image.finish('没有引用需要查找的图片QAQ, 请使用本命令时直接回复相关消息') + + +@recommend_image.handle() +async def handle_illust_recommend(bot: Bot, event: GroupMessageEvent, state: T_State): + pid = state.get('pid') + if not pid: + logger.debug(f'Recommend image | 没有匹配到图片pid, 操作已取消') + await recommend_image.finish('没有匹配到相关图片QAQ, 请确认搜索的图片是在 Pixiv 上的作品') + + recommend_result = await PixivIllust(pid=pid).get_recommend(init_limit=36) + if recommend_result.error: + logger.warning(f'Recommend image | 获取相似作品信息失败, pid: {pid}, error: {recommend_result.info}') + await recommend_image.finish('获取相关作品信息失败QAQ, 原作品可能已经被删除') + + # 获取推荐作品的信息 + await recommend_image.send('稍等, 正在获取相似作品~') + pid_list = [x.get('id') for x in recommend_result.result.get('illusts') if x.get('illustType') == 0] + tasks = [PixivIllust(pid=x).get_illust_data() for x in pid_list] + recommend_illust_data_result = await asyncio.gather(*tasks) + + # 执行 r18 权限检查 + if isinstance(event, PrivateMessageEvent): + user_id = event.user_id + auth_checker = await PermissionChecker(self_bot=DBBot(self_qq=int(bot.self_id))). \ + check_auth_node(auth_id=user_id, auth_type='user', auth_node='search_image.allow_recommend_r18') + elif isinstance(event, GroupMessageEvent): + group_id = event.group_id + auth_checker = await PermissionChecker(self_bot=DBBot(self_qq=int(bot.self_id))). \ + check_auth_node(auth_id=group_id, auth_type='group', auth_node='search_image.allow_recommend_r18') + else: + auth_checker = 0 + + # 筛选推荐作品 筛选条件 收藏不少于2k 点赞数不少于收藏一半 点赞率大于百分之五 + if auth_checker == 1: + filtered_illust_data_result = [x for x in recommend_illust_data_result if ( + x.success() and + 2000 <= x.result.get('bookmark_count') <= 2 * x.result.get('like_count') and + x.result.get('view_count') <= 20 * x.result.get('like_count') + )] + else: + filtered_illust_data_result = [x for x in recommend_illust_data_result if ( + x.success() and + not x.result.get('is_r18') and + 2000 <= x.result.get('bookmark_count') <= 2 * x.result.get('like_count') and + x.result.get('view_count') <= 20 * x.result.get('like_count') + )] + + # 从筛选结果里面随机挑三个 + if len(filtered_illust_data_result) > 3: + illust_list = [PixivIllust(pid=x.result.get('pid')) for x in random.sample(filtered_illust_data_result, k=3)] + else: + illust_list = [PixivIllust(pid=x.result.get('pid')) for x in filtered_illust_data_result] + + if not illust_list: + logger.info(f'Recommend image | 筛选结果为0, 没有找到符合要求的相似作品') + await recommend_image.finish('没有找到符合要求的相似作品QAQ') + + # 直接下载图片 + tasks = [x.get_sending_msg() for x in illust_list] + illust_download_result = await asyncio.gather(*tasks) + + for img, info in [x.result for x in illust_download_result if x.success()]: + img_seg = MessageSegment.image(file=img) + try: + await recommend_image.send(Message(img_seg).append(info)) + except Exception as e: + logger.warning(f'Recommend image | 发送图片失败, error: {repr(e)}') + continue + logger.info(f'Recommend image | User: {event.user_id} 已获取相似图片') diff --git a/omega_miya/plugins/search_image/config.py b/omega_miya/plugins/search_image/config.py new file mode 100644 index 00000000..fe884edc --- /dev/null +++ b/omega_miya/plugins/search_image/config.py @@ -0,0 +1,22 @@ +""" +@Author : Ailitonia +@Date : 2021/06/16 22:53 +@FileName : config.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from pydantic import BaseSettings + + +class Config(BaseSettings): + # plugin custom config + # 识图引擎开关, 使用优先级saucenao>iqdb>ascii2d + enable_saucenao: bool = True + enable_iqdb: bool = True + enable_ascii2d: bool = True + + class Config: + extra = "ignore" diff --git a/omega_miya/plugins/search_image/utils.py b/omega_miya/plugins/search_image/utils.py index 4a7a5240..9a0b7b65 100644 --- a/omega_miya/plugins/search_image/utils.py +++ b/omega_miya/plugins/search_image/utils.py @@ -1,43 +1,32 @@ import re +from typing import Dict, Callable, Awaitable import nonebot from bs4 import BeautifulSoup from nonebot import logger -from omega_miya.utils.Omega_plugin_utils import HttpFetcher, PicEncoder +from omega_miya.utils.Omega_plugin_utils import HttpFetcher from omega_miya.utils.Omega_Base import Result global_config = nonebot.get_driver().config API_KEY = global_config.saucenao_api_key -API_URL = 'https://saucenao.com/search.php' +API_URL_SAUCENAO = 'https://saucenao.com/search.php' API_URL_ASCII2D = 'https://ascii2d.net/search/url/' +API_URL_IQDB = 'https://iqdb.org/' HEADERS = {'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' 'Chrome/89.0.4389.114 Safari/537.36'} - -# 图片转base64 -async def pic_2_base64(url: str) -> Result.TextResult: - fetcher = HttpFetcher(timeout=10, flag='search_image_get_image', headers=HEADERS) - bytes_result = await fetcher.get_bytes(url=url) - if bytes_result.error: - return Result.TextResult(error=True, info='Image download failed', result='') - - encode_result = PicEncoder.bytes_to_b64(image=bytes_result.result) - - if encode_result.success(): - return Result.TextResult(error=False, info='Success', result=encode_result.result) - else: - return Result.TextResult(error=True, info=encode_result.info, result='') +T_SearchEngine = Callable[[str], Awaitable[Result.DictListResult]] # 获取识别结果 Saucenao模块 -async def get_saucenao_identify_result(url: str) -> Result.ListResult: +async def get_saucenao_identify_result(url: str) -> Result.DictListResult: fetcher = HttpFetcher(timeout=10, flag='search_image_saucenao', headers=HEADERS) if not API_KEY: logger.opt(colors=True).warning(f'Saucenao API KEY未配置, 无法使用Saucenao API进行识图!') - return Result.ListResult(error=True, info='Saucenao API KEY未配置', result=[]) + return Result.DictListResult(error=True, info='Saucenao API KEY未配置', result=[]) __payload = {'output_type': 2, 'api_key': API_KEY, @@ -45,17 +34,17 @@ async def get_saucenao_identify_result(url: str) -> Result.ListResult: 'numres': 6, 'db': 999, 'url': url} - saucenao_result = await fetcher.get_json(url=API_URL, params=__payload) + saucenao_result = await fetcher.get_json(url=API_URL_SAUCENAO, params=__payload) if saucenao_result.error: logger.warning(f'get_saucenao_identify_result failed, Network error: {saucenao_result.info}') - return Result.ListResult(error=True, info=f'Network error: {saucenao_result.info}', result=[]) + return Result.DictListResult(error=True, info=f'Network error: {saucenao_result.info}', result=[]) __result_json = saucenao_result.result if __result_json['header']['status'] != 0: logger.error(f"get_saucenao_identify_result failed, DataSource error, " f"status code: {__result_json['header']['status']}") - return Result.ListResult( + return Result.DictListResult( error=True, info=f"DataSource error, status code: {__result_json['header']['status']}", result=[]) __result = [] @@ -71,23 +60,23 @@ async def get_saucenao_identify_result(url: str) -> Result.ListResult: except Exception as res_err: logger.warning(f"get_saucenao_identify_result failed: {repr(res_err)}, can not resolve results") continue - return Result.ListResult(error=False, info='Success', result=__result) + return Result.DictListResult(error=False, info='Success', result=__result) # 获取识别结果 ascii2d模块 -async def get_ascii2d_identify_result(url: str) -> Result.ListResult: +async def get_ascii2d_identify_result(url: str) -> Result.DictListResult: fetcher = HttpFetcher(timeout=10, flag='search_image_ascii2d', headers=HEADERS) search_url = f'{API_URL_ASCII2D}{url}' saucenao_redirects_result = await fetcher.get_text(url=search_url, allow_redirects=False) if saucenao_redirects_result.error: logger.error(f'get_ascii2d_identify_result failed: 获取识别结果url发生错误, 错误信息详见日志.') - return Result.ListResult(error=True, info=f'Get identify result url failed', result=[]) + return Result.DictListResult(error=True, info=f'Get identify result url failed', result=[]) ascii2d_color_url = saucenao_redirects_result.headers.get('Location') if not ascii2d_color_url: - logger.error(f'get_ascii2d_identify_result failed: 获取识别结果url发生错误, 可能被流量限制.') - return Result.ListResult(error=True, info=f'Get identify result url failed, may be limited', result=[]) + logger.error(f'get_ascii2d_identify_result failed: 获取识别结果url发生错误, 可能被流量限制, 或图片大小超过5Mb.') + return Result.DictListResult(error=True, info=f'Get identify result url failed, may be limited', result=[]) ascii2d_bovw_url = re.sub( r'https://ascii2d\.net/search/color/', r'https://ascii2d.net/search/bovw/', ascii2d_color_url) @@ -101,7 +90,7 @@ async def get_ascii2d_identify_result(url: str) -> Result.ListResult: pre_bs_list.append(bovw_res.result) if not pre_bs_list: logger.error(f'get_ascii2d_identify_result ERROR: 获取识别结果异常, 错误信息详见日志.') - return Result.ListResult(error=True, info=f'Get identify result data failed', result=[]) + return Result.DictListResult(error=True, info=f'Get identify result data failed', result=[]) __result = [] @@ -152,4 +141,111 @@ async def get_ascii2d_identify_result(url: str) -> Result.ListResult: except Exception as row_err: logger.warning(f'get_ascii2d_identify_result ERROR: {repr(row_err)}, 解搜索结果条目时发生错误.') continue - return Result.ListResult(error=False, info=f'Success', result=__result) + return Result.DictListResult(error=False, info=f'Success', result=__result) + + +# 获取识别结果 iqdb模块 +async def get_iqdb_identify_result(url: str) -> Result.DictListResult: + headers = HEADERS.copy().update({ + 'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,' + 'image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9', + 'accept-encoding': 'gzip, deflate', + 'accept-language': 'zh-CN,zh;q=0.9', + 'Cache-Control': 'max-age=0', + 'Connection': 'keep-alive', + 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundarycljlxd876c1ld4Zr', + 'dnt': '1', + 'Host': 'iqdb.org', + 'Origin': 'https://iqdb.org', + 'Referer': 'https://iqdb.org/', + 'sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="91", "Chromium";v="91"', + 'sec-ch-ua-mobile': '?0', + 'Sec-Fetch-Dest': 'document', + 'Sec-Fetch-Mode': 'navigate', + 'Sec-Fetch-Site': 'same-origin', + 'Sec-Fetch-User': '?1', + 'sec-gpc': '1', + 'Upgrade-Insecure-Requests': '1', + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) ' + 'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36' + }) + fetcher = HttpFetcher(timeout=30, flag='search_image_iqdb', headers=headers) + data = fetcher.FormData(boundary='----WebKitFormBoundarycljlxd876c1ld4Zr') + data.add_field(name='MAX_FILE_SIZE', value='') + for i in [1, 2, 3, 4, 5, 6, 11, 13]: + data.add_field(name='service[]', value=str(i)) + data.add_field(name='file', value=b'', content_type='application/octet-stream', filename='') + data.add_field(name='url', value=url) + iqdb_result = await fetcher.post_text(url=API_URL_IQDB, data=data) + + if iqdb_result.error or iqdb_result.status != 200: + logger.warning(f'get_iqdb_identify_result failed, 获取识别结果失败: {iqdb_result.status}, {iqdb_result.info}') + return Result.DictListResult(error=True, info=f'Get identify result failed: {iqdb_result.info}', result=[]) + + try: + gallery_soup = BeautifulSoup(iqdb_result.result, 'lxml') + # 搜索结果 + result_div = gallery_soup.find('div', {'id': 'pages', 'class': 'pages'}).children + # 从搜索结果中解析具体每一个结果 + result_list = [x.find_all('tr') for x in result_div if x.name == 'div'] + except Exception as page_err: + logger.warning(f'get_iqdb_identify_result failed: {repr(page_err)}, 解析结果页时发生错误.') + return Result.DictListResult(error=True, info=f'Parse identify result failed: {repr(page_err)}', result=[]) + + result = [] + for item in result_list: + try: + if item[0].get_text() == 'Best match': + # 第二行是匹配缩略图及链接 + urls = '\n'.join([str(x.find('a').get('href')).strip('/') for x in item if x.find('a')]) + img = item[1].find('img').get('src') + # 最后一行是相似度 + similarity = item[-1].get_text() + result.append({ + 'similarity': similarity, + 'thumbnail': f'https://iqdb.org{img}', + 'index_name': f'iqdb - Best match', + 'ext_urls': urls + }) + elif item[0].get_text() == 'Additional match': + # 第二行是匹配缩略图及链接 + urls = '\n'.join([str(x.find('a').get('href')).strip('/') for x in item if x.find('a')]) + img = item[1].find('img').get('src') + # 最后一行是相似度 + similarity = item[-1].get_text() + result.append({ + 'similarity': similarity, + 'thumbnail': f'https://iqdb.org{img}', + 'index_name': f'iqdb - Additional match', + 'ext_urls': urls + }) + elif item[0].get_text() == 'Possible match': + # # 第二行是匹配缩略图及链接 + # urls = '\n'.join([str(x.find('a').get('href')).strip('/') for x in item if x.find('a')]) + # img = item[1].find('img').get('src') + # # 最后一行是相似度 + # similarity = item[-1].get_text() + # result.append({ + # 'similarity': similarity, + # 'thumbnail': f'https://iqdb.org{img}', + # 'index_name': f'iqdb - Possible match', + # 'ext_urls': urls + # }) + pass + except Exception as parse_err: + logger.warning(f'get_iqdb_identify_result parse error: {repr(parse_err)}, 解搜索结果条目时发生错误..') + return Result.DictListResult(error=False, info='Success', result=result) + + +# 可用的识图api +SEARCH_ENGINE: Dict[str, T_SearchEngine] = { + 'saucenao': get_saucenao_identify_result, + 'iqdb': get_iqdb_identify_result, + 'ascii2d': get_ascii2d_identify_result +} + + +__all__ = [ + 'HEADERS', + 'SEARCH_ENGINE' +] diff --git a/omega_miya/plugins/setu/__init__.py b/omega_miya/plugins/setu/__init__.py index 8bcf5ff0..362d352a 100644 --- a/omega_miya/plugins/setu/__init__.py +++ b/omega_miya/plugins/setu/__init__.py @@ -1,5 +1,8 @@ +import os +import re import asyncio -from nonebot import CommandGroup, on_command, export, logger +import aiofiles +from nonebot import CommandGroup, on_command, export, get_driver, logger from nonebot.rule import to_me from nonebot.permission import SUPERUSER from nonebot.typing import T_State @@ -7,10 +10,25 @@ from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.cqhttp.permission import GROUP, PRIVATE_FRIEND from nonebot.adapters.cqhttp import MessageSegment -from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PluginCoolDown -from omega_miya.utils.Omega_Base import DBPixivillust +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PluginCoolDown, PermissionChecker +from omega_miya.utils.Omega_plugin_utils import PicEncoder, PicEffector, MsgSender, ProcessUtils +from omega_miya.utils.Omega_Base import DBBot, DBPixivillust from omega_miya.utils.pixiv_utils import PixivIllust from .utils import add_illust +from .config import Config + + +__global_config = get_driver().config +plugin_config = Config(**__global_config.dict()) +IMAGE_NUM_LIMIT = plugin_config.image_num_limit +ENABLE_NODE_CUSTOM = plugin_config.enable_node_custom +ENABLE_MOE_FLASH = plugin_config.enable_moe_flash +ENABLE_SETU_FLASH = plugin_config.enable_setu_flash +ENABLE_SETU_GAUSSIAN_BLUR = plugin_config.enable_setu_gaussian_blur +ENABLE_SETU_GAUSSIAN_NOISE = plugin_config.enable_setu_gaussian_noise +AUTO_RECALL_TIME = plugin_config.auto_recall_time +ENABLE_MOE_AUTO_RECALL = plugin_config.enable_moe_auto_recall +ENABLE_SETU_AUTO_RECALL = plugin_config.enable_setu_auto_recall # Custom plugin usage text @@ -27,6 +45,7 @@ **AuthNode** setu moepic +allow_r18 **CoolDown** 群组共享冷却时间 @@ -40,18 +59,20 @@ **SuperUser Only** /图库统计 +/图库查询 [*keywords] /导入图库''' # 声明本插件可配置的权限节点 __plugin_auth_node__ = [ PluginCoolDown.skip_auth_node, 'setu', + 'allow_r18', 'moepic' ] # 声明本插件的冷却时间配置 __plugin_cool_down__ = [ - PluginCoolDown(PluginCoolDown.user_type, 1), + PluginCoolDown(PluginCoolDown.user_type, 2), PluginCoolDown(PluginCoolDown.group_type, 1) ] @@ -79,7 +100,7 @@ async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): # 处理r18 state['nsfw_tag'] = 1 for tag in args.copy(): - if tag in ['r18', 'R18', 'r-18', 'R-18']: + if re.match(r'[Rr]-?18[Gg]?', tag): args.remove(tag) state['nsfw_tag'] = 2 state['tags'] = list(args) @@ -95,12 +116,29 @@ async def handle_setu(bot: Bot, event: MessageEvent, state: T_State): nsfw_tag = state['nsfw_tag'] tags = state['tags'] + # 处理R18权限 + if nsfw_tag > 1: + if isinstance(event, PrivateMessageEvent): + user_id = event.user_id + auth_checker = await PermissionChecker(self_bot=DBBot(self_qq=int(bot.self_id))). \ + check_auth_node(auth_id=user_id, auth_type='user', auth_node='setu.allow_r18') + elif isinstance(event, GroupMessageEvent): + group_id = event.group_id + auth_checker = await PermissionChecker(self_bot=DBBot(self_qq=int(bot.self_id))). \ + check_auth_node(auth_id=group_id, auth_type='group', auth_node='setu.allow_r18') + else: + auth_checker = 0 + + if auth_checker != 1: + logger.warning(f"User: {event.user_id} 请求涩图被拒绝, 没有 allow_r18 权限") + await setu.finish('R18禁止! 不准开车车!') + if tags: - pid_res = await DBPixivillust.list_illust(keywords=tags, num=3, nsfw_tag=nsfw_tag) + pid_res = await DBPixivillust.list_illust(keywords=tags, num=IMAGE_NUM_LIMIT, nsfw_tag=nsfw_tag) pid_list = pid_res.result else: # 没有tag则随机获取 - pid_res = await DBPixivillust.rand_illust(num=3, nsfw_tag=nsfw_tag) + pid_res = await DBPixivillust.rand_illust(num=IMAGE_NUM_LIMIT, nsfw_tag=nsfw_tag) pid_list = pid_res.result if not pid_list: @@ -110,29 +148,75 @@ async def handle_setu(bot: Bot, event: MessageEvent, state: T_State): # 处理article中图片内容 tasks = [] for pid in pid_list: - tasks.append(PixivIllust(pid=pid).pic_2_base64()) + tasks.append(PixivIllust(pid=pid).load_illust_pic()) p_res = await asyncio.gather(*tasks) + + # 处理图片消息段, 之后再根据ENABLE_NODE_CUSTOM确定消息发送方式 fault_count = 0 + image_seg_list = [] for image_res in p_res: try: - if not image_res.success(): + if image_res.error: fault_count += 1 logger.warning(f'图片下载失败, error: {image_res.info}') continue + if ENABLE_SETU_GAUSSIAN_NOISE: + image_res = await PicEffector(image=image_res.result).gaussian_noise(sigma=16) + if image_res.error: + fault_count += 1 + logger.warning(f'处理图片高斯噪声处理失败, error: {image_res.info}') + continue + if ENABLE_SETU_GAUSSIAN_BLUR: + image_res = await PicEffector(image=image_res.result).gaussian_blur(radius=4) + if image_res.error: + fault_count += 1 + logger.warning(f'处理图片高斯模糊失败, error: {image_res.info}') + continue + image_res = await PicEncoder.bytes_to_file(image=image_res.result, folder_flag='setu') + if image_res.error: + fault_count += 1 + logger.warning(f'图片转换失败, error: {image_res.info}') + continue else: - img_seg = MessageSegment.image(image_res.result) - # 发送图片 - await setu.send(img_seg) + if ENABLE_SETU_FLASH: + image_seg_list.append(MessageSegment.image(image_res.result, type_='flash')) + else: + image_seg_list.append(MessageSegment.image(image_res.result)) except Exception as e: - logger.warning(f"图片发送失败, {group_id} / {event.user_id}. error: {repr(e)}") + logger.warning(f'预处理图片失败: {repr(e)}') continue + sent_msg_ids = [] + # 根据ENABLE_NODE_CUSTOM处理消息发送 + if ENABLE_NODE_CUSTOM and isinstance(event, GroupMessageEvent): + msg_sender = MsgSender(bot=bot, log_flag='Setu') + await msg_sender.safe_send_group_node_custom(group_id=event.group_id, message_list=image_seg_list) + else: + for msg_seg in image_seg_list: + try: + sent_msg_id = await setu.send(msg_seg) + sent_msg_ids.append(sent_msg_id.get('message_id') if isinstance(sent_msg_id, dict) else None) + except Exception as e: + logger.warning(f'图片发送失败, {group_id} / {event.user_id}. error: {repr(e)}') + if fault_count == len(pid_list): logger.info(f"{group_id} / {event.user_id} 没能看到他/她想要的涩图, 图片下载失败, {pid_list}") await setu.finish('似乎出现了网络问题, 所有的图片都下载失败了QAQ') else: logger.info(f"{group_id} / {event.user_id} 找到了他/她想要的涩图, {pid_list}") + if ENABLE_SETU_AUTO_RECALL: + logger.info(f"{group_id} / {event.user_id} 撤回已发送涩图...") + await asyncio.sleep(AUTO_RECALL_TIME) + for msg_id in sent_msg_ids: + if not msg_id: + continue + try: + await bot.delete_msg(message_id=msg_id) + except Exception as e: + logger.warning(f'撤回图片失败, {group_id} / {event.user_id}, msg_id: {msg_id}. error: {repr(e)}') + continue + # 注册事件响应器 moepic = Setu.command( @@ -151,7 +235,7 @@ async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): args = set(str(event.get_plaintext()).strip().split()) # 排除r18 for tag in args.copy(): - if tag in ['r18', 'R18', 'r-18', 'R-18']: + if re.match(r'[Rr]-?18[Gg]?', tag): args.remove(tag) state['tags'] = list(args) @@ -166,11 +250,11 @@ async def handle_moepic(bot: Bot, event: MessageEvent, state: T_State): tags = state['tags'] if tags: - pid_res = await DBPixivillust.list_illust(keywords=tags, num=3, nsfw_tag=0) + pid_res = await DBPixivillust.list_illust(keywords=tags, num=IMAGE_NUM_LIMIT, nsfw_tag=0) pid_list = pid_res.result else: # 没有tag则随机获取 - pid_res = await DBPixivillust.rand_illust(num=3, nsfw_tag=0) + pid_res = await DBPixivillust.rand_illust(num=IMAGE_NUM_LIMIT, nsfw_tag=0) pid_list = pid_res.result if not pid_list: @@ -181,29 +265,63 @@ async def handle_moepic(bot: Bot, event: MessageEvent, state: T_State): # 处理article中图片内容 tasks = [] for pid in pid_list: - tasks.append(PixivIllust(pid=pid).pic_2_base64()) + tasks.append(PixivIllust(pid=pid).load_illust_pic()) p_res = await asyncio.gather(*tasks) + + # 处理图片消息段, 之后再根据ENABLE_NODE_CUSTOM确定消息发送方式 fault_count = 0 + image_seg_list = [] for image_res in p_res: try: - if not image_res.success(): + if image_res.error: fault_count += 1 logger.warning(f'图片下载失败, error: {image_res.info}') continue + image_res = await PicEncoder.bytes_to_file(image=image_res.result, folder_flag='moepic') + if image_res.error: + fault_count += 1 + logger.warning(f'图片转换失败, error: {image_res.info}') + continue else: - img_seg = MessageSegment.image(image_res.result) - # 发送图片 - await moepic.send(img_seg) + if ENABLE_MOE_FLASH: + image_seg_list.append(MessageSegment.image(image_res.result, type_='flash')) + else: + image_seg_list.append(MessageSegment.image(image_res.result)) except Exception as e: - logger.warning(f"图片发送失败, {group_id} / {event.user_id}. error: {repr(e)}") + logger.warning(f'预处理图片失败: {repr(e)}') continue + sent_msg_ids = [] + # 根据ENABLE_NODE_CUSTOM处理消息发送 + if ENABLE_NODE_CUSTOM and isinstance(event, GroupMessageEvent): + msg_sender = MsgSender(bot=bot, log_flag='Moepic') + await msg_sender.safe_send_group_node_custom(group_id=event.group_id, message_list=image_seg_list) + else: + for msg_seg in image_seg_list: + try: + sent_msg_id = await moepic.send(msg_seg) + sent_msg_ids.append(sent_msg_id.get('message_id') if isinstance(sent_msg_id, dict) else None) + except Exception as e: + logger.warning(f'图片发送失败, {group_id} / {event.user_id}. error: {repr(e)}') + if fault_count == len(pid_list): logger.info(f"{group_id} / {event.user_id} 没能看到他/她想要的萌图, 图片下载失败, {pid_list}") await moepic.finish('似乎出现了网络问题, 所有的图片都下载失败了QAQ') else: logger.info(f"{group_id} / {event.user_id} 找到了他/她想要的萌图, {pid_list}") + if ENABLE_MOE_AUTO_RECALL: + logger.info(f"{group_id} / {event.user_id} 撤回已发送萌图...") + await asyncio.sleep(AUTO_RECALL_TIME) + for msg_id in sent_msg_ids: + if not msg_id: + continue + try: + await bot.delete_msg(message_id=msg_id) + except Exception as e: + logger.warning(f'撤回图片失败, {group_id} / {event.user_id}, msg_id: {msg_id}. error: {repr(e)}') + continue + # 注册事件响应器 setu_stat = on_command('图库统计', rule=to_me(), permission=SUPERUSER, priority=20, block=True) @@ -212,16 +330,45 @@ async def handle_moepic(bot: Bot, event: MessageEvent, state: T_State): @setu_stat.handle() async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): status_res = await DBPixivillust.status() + if status_res.error: + logger.error(f'{event.user_id} 执行图库统计失败, {status_res.info}') + await setu_stat.finish('查询失败了QAQ, 请稍后再试') + msg = f"本地数据库统计:\n\n" \ f"全部: {status_res.result.get('total')}\n" \ f"萌图: {status_res.result.get('moe')}\n" \ f"涩图: {status_res.result.get('setu')}\n" \ f"R18: {status_res.result.get('r18')}" + logger.info(f'{event.user_id} 执行图库统计成功') await setu_stat.finish(msg) # 注册事件响应器 -setu_import = on_command('导入图库', rule=to_me(), permission=SUPERUSER, priority=20, block=True) +setu_count = on_command('图库查询', aliases={'查询图库'}, rule=to_me(), permission=SUPERUSER, priority=20, block=True) + + +@setu_count.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().split() + if not args: + logger.info(f'{event.user_id} 执行图库查询被取消, 未指定查询关键字') + await setu_count.finish('无查询关键字QAQ, 查询取消') + + count_res = await DBPixivillust.count_keywords(keywords=args) + if count_res.error: + logger.error(f'{event.user_id} 执行图库查询失败, {count_res.info}') + await setu_count.finish('无查询关键字QAQ, 查询取消') + + msg = f"查询关键字 {'/'.join(args)} 结果:\n\n" \ + f"全部: {count_res.result.get('total')}\n" \ + f"萌图: {count_res.result.get('moe')}\n" \ + f"涩图: {count_res.result.get('setu')}\n" \ + f"R18: {count_res.result.get('r18')}" + await setu_count.finish(msg) + + +# 注册事件响应器 +setu_import = on_command('导入图库', aliases={'图库导入'}, rule=to_me(), permission=SUPERUSER, priority=20, block=True) # 修改默认参数处理 @@ -254,11 +401,10 @@ async def handle_setu_import(bot: Bot, event: MessageEvent, state: T_State): if mode == 'moe': nsfw_tag = 0 + force_tag = True else: nsfw_tag = 1 - - import os - import re + force_tag = False # 文件操作 import_pid_file = os.path.join(os.path.dirname(__file__), 'import_pid.txt') @@ -268,8 +414,8 @@ async def handle_setu_import(bot: Bot, event: MessageEvent, state: T_State): pid_list = [] try: - with open(import_pid_file) as f: - lines = f.readlines() + async with aiofiles.open(import_pid_file, 'r') as f: + lines = await f.readlines() for line in lines: if not re.match(r'^[0-9]+$', line): logger.debug(f'setu_import: 导入列表中有非数字字符: {line}') @@ -279,37 +425,17 @@ async def handle_setu_import(bot: Bot, event: MessageEvent, state: T_State): logger.error(f'setu_import: 读取导入列表失败, error: {repr(e)}') await setu_import.finish('错误: 读取导入列表失败QAQ') - await setu_import.send('已读取导入文件列表, 开始获取作品信息~') - # 对列表去重 pid_list = list(set(pid_list)) - - # 导入操作 + pid_list.sort() all_count = len(pid_list) - success_count = 0 - # 全部一起并发api撑不住, 做适当切分 - # 每个切片数量 - seg_n = 10 - pid_seg_list = [] - for i in range(0, all_count, seg_n): - pid_seg_list.append(pid_list[i:i + seg_n]) - # 每个切片打包一个任务 - seg_len = len(pid_seg_list) - process_rate = 0 - for seg_list in pid_seg_list: - tasks = [] - for pid in seg_list: - tasks.append(add_illust(pid=pid, nsfw_tag=nsfw_tag)) - # 进行异步处理 - _res = await asyncio.gather(*tasks) - # 对结果进行计数 - for item in _res: - if item.success(): - success_count += 1 - # 显示进度 - process_rate += 1 - if process_rate % 10 == 0: - await setu_import.send(f'导入操作中,已完成: {process_rate}/{seg_len}') + await setu_import.send('已读取导入文件列表, 开始获取作品信息~') + logger.info(f'setu_import: 读取导入文件列表完成, 总计: {all_count}, 开始导入...') + # 开始导入操作 + # 全部一起并发网络撑不住, 做适当切分 + tasks = [add_illust(pid=pid, nsfw_tag=nsfw_tag, force_tag=force_tag) for pid in pid_list] + _res = await ProcessUtils.fragment_process(tasks=tasks, fragment_size=50, log_flag='Setu Import') + success_count = len([x for x in _res if x.success()]) logger.info(f'setu_import: 导入操作已完成, 成功: {success_count}, 总计: {all_count}') await setu_import.send(f'导入操作已完成, 成功: {success_count}, 总计: {all_count}') diff --git a/omega_miya/plugins/setu/config.py b/omega_miya/plugins/setu/config.py new file mode 100644 index 00000000..5005c9ee --- /dev/null +++ b/omega_miya/plugins/setu/config.py @@ -0,0 +1,40 @@ +""" +@Author : Ailitonia +@Date : 2021/06/03 22:05 +@FileName : config.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from pydantic import BaseSettings + + +class Config(BaseSettings): + # plugin custom config + # 每次查询的图片数量限制 + image_num_limit: int = 3 + # 启用使用群组转发自定义消息节点的模式发送信息, 仅在群组消息中生效 + # 发送速度受限于网络上传带宽, 有可能导致超时或发送失败, 请酌情启用 + enable_node_custom: bool = False + + # 实验性功能, 可能引发协议端奇怪的问题(指base64图片+闪照的组合) + # 以下选项产生的效果均在权限验证之后,并直接影响最终发送的图片 + # 启用使用闪照模式发送萌图, 仅影响"/来点萌图"命令 + enable_moe_flash: bool = False + # 启用使用闪照模式发送涩图, 仅影响"/来点涩图"命令 + enable_setu_flash: bool = False + # 启用使用高斯模糊提前处理待发送的涩图, 仅影响"/来点涩图"命令, 可与enable_setu_gaussian_noise一同使用, 可能会导致处理和发送图片时间提升 + enable_setu_gaussian_blur: bool = False + # 启用使用高斯噪声提前处理待发送的涩图, 仅影响"/来点涩图"命令, 可与enable_setu_gaussian_blur一同使用, 可能会导致处理发送图片时间提升 + enable_setu_gaussian_noise: bool = True + + # 启用发送图片后自动撤回, 默认撤回时间10秒 + # !如果启用了转发消息节点模式(enable_node_custom=True)则以下选项不会生效! + auto_recall_time: int = 25 + enable_moe_auto_recall: bool = False + enable_setu_auto_recall: bool = True + + class Config: + extra = "ignore" diff --git a/omega_miya/plugins/setu/utils.py b/omega_miya/plugins/setu/utils.py index e2960424..14a811ad 100644 --- a/omega_miya/plugins/setu/utils.py +++ b/omega_miya/plugins/setu/utils.py @@ -1,8 +1,9 @@ +from nonebot import logger from omega_miya.utils.Omega_Base import DBPixivillust, Result from omega_miya.utils.pixiv_utils import PixivIllust -async def add_illust(pid: int, nsfw_tag: int) -> Result.IntResult: +async def add_illust(pid: int, nsfw_tag: int, *, force_tag: bool = False) -> Result.IntResult: illust_result = await PixivIllust(pid=pid).get_illust_data() if illust_result.success(): @@ -11,14 +12,32 @@ async def add_illust(pid: int, nsfw_tag: int) -> Result.IntResult: uid = illust_data.get('uid') uname = illust_data.get('uname') url = illust_data.get('url') + width = illust_data.get('width') + height = illust_data.get('height') tags = illust_data.get('tags') is_r18 = illust_data.get('is_r18') + illust_pages = illust_data.get('illust_pages') if is_r18: nsfw_tag = 2 illust = DBPixivillust(pid=pid) - _res = await illust.add(uid=uid, title=title, uname=uname, nsfw_tag=nsfw_tag, tags=tags, url=url) - return _res + illust_add_result = await illust.add(uid=uid, title=title, uname=uname, nsfw_tag=nsfw_tag, + width=width, height=height, tags=tags, url=url, force_tag=force_tag) + if illust_add_result.error: + logger.error(f'Setu | Adding illust to database failed, pid: {pid}, error: {illust_add_result.info}') + return illust_add_result + + for page, urls in illust_pages.items(): + original = urls.get('original') + regular = urls.get('regular') + small = urls.get('small') + thumb_mini = urls.get('thumb_mini') + page_upgrade_result = await illust.upgrade_page( + page=page, original=original, regular=regular, small=small, thumb_mini=thumb_mini) + if page_upgrade_result.error: + logger.warning(f'Setu | upgrade illust page {page} failed: {page_upgrade_result.info}') + return illust_add_result else: + logger.error(f'Setu | Getting illust data failed, pid: {pid}, error: {illust_result.info}') return Result.IntResult(error=True, info=illust_result.info, result=-1) diff --git a/omega_miya/plugins/shindan_maker/__init__.py b/omega_miya/plugins/shindan_maker/__init__.py new file mode 100644 index 00000000..5a4de643 --- /dev/null +++ b/omega_miya/plugins/shindan_maker/__init__.py @@ -0,0 +1,201 @@ +""" +@Author : Ailitonia +@Date : 2021/06/28 21:41 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : shindan_maker 无聊的占卜插件 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import re +import datetime +from typing import Dict +from nonebot import MatcherGroup, export, logger +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, PluginCoolDown, OmegaRules +from .data_source import ShindanMaker + + +# Custom plugin usage text +__plugin_name__ = 'ShindanMaker' +__plugin_usage__ = r'''【ShindanMaker 占卜】 +使用ShindanMaker进行各种奇怪的占卜 +只能在群里使用 +就是要公开处刑! + +**Permission** +Command & Lv.30 +or AuthNode + +**AuthNode** +basic + +**Usage** +/ShindanMaker [占卜名称] [占卜对象名称]''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + PluginCoolDown.skip_auth_node, + 'basic' +] + +# # 声明本插件的冷却时间配置 +# __plugin_cool_down__ = [ +# PluginCoolDown(PluginCoolDown.user_type, 1), +# PluginCoolDown(PluginCoolDown.group_type, 1) +# ] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +# 缓存占卜名称与对应id +SHINDANMAKER_CACHE: Dict[str, int] = {} + + +# 注册事件响应器 +shindan_maker = MatcherGroup( + type='message', + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='shindan_maker', + command=True, + level=30, + auth_node='basic'), + permission=GROUP, + priority=20, + block=True) + + +shindan_maker_default = shindan_maker.on_command( + 'ShindanMaker', aliases={'占卜', 'shindanmaker', 'SHINDANMAKER', 'Shindan', 'shindan', 'SHINDAN'}) + + +# 修改默认参数处理 +@shindan_maker_default.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().split() + if not args: + await shindan_maker_default.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await shindan_maker_default.finish('操作已取消') + + +@shindan_maker_default.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().split() + state['id'] = 0 + if not args: + pass + elif args and len(args) == 1: + state['shindan_name'] = args[0] + elif args and len(args) == 2: + state['shindan_name'] = args[0] + state['input_name'] = args[1] + else: + await shindan_maker_default.finish('参数错误QAQ') + + # 特殊处理@人 + if len(event.message) >= 2: + if event.message[1].type == 'at': + at_qq = event.message[1].data.get('qq') + group_member_list = await bot.get_group_member_list(group_id=event.group_id) + nickname = [x for x in group_member_list if x.get('user_id') == int(at_qq)] + if nickname: + input_name = nickname[0].get('card') if nickname[0].get('card') else nickname[0].get('nickname') + if input_name: + state['input_name'] = input_name + + +@shindan_maker_default.got('shindan_name', prompt='你想做什么占卜呢?\n不知道的话可以输入关键词进行搜索哦~') +async def handle_shindan_name(bot: Bot, event: GroupMessageEvent, state: T_State): + global SHINDANMAKER_CACHE + + shindan_name = state['shindan_name'] + shindan_id = SHINDANMAKER_CACHE.get(shindan_name, 0) + if shindan_id == 0: + shindan_name_result = await ShindanMaker.search(keyword=shindan_name) + if shindan_name_result.error: + logger.error(f'User: {event.user_id} 获取 ShindanMaker 占卜信息失败, {shindan_name_result.info}') + await shindan_maker_default.finish('获取ShindanMaker占卜信息失败了QAQ, 请稍后再试') + else: + for item in shindan_name_result.result: + if item.get('name'): + SHINDANMAKER_CACHE.update({ + re.sub(r'\s', '', item.get('name')): item.get('id', 0) + }) + shindan_id = SHINDANMAKER_CACHE.get(shindan_name, 0) + if shindan_id == 0: + shindan_list = '】\n【'.join( + [re.sub(r'\s', '', x.get('name')) for x in shindan_name_result.result if x.get('name')]) + msg = f'搜索到了以下占卜\n{"="*12}\n【{shindan_list}】\n{"="*12}\n' \ + f'请使用占卜名称(方括号里面的完整名称)重新开始!' + await shindan_maker_default.finish(msg) + + state['id'] = shindan_id + + +@shindan_maker_default.got('input_name', prompt='请输入您想要进行占卜的人名:') +async def handle_input_name(bot: Bot, event: GroupMessageEvent, state: T_State): + shindan_name = state['shindan_name'] + input_name = state['input_name'] + shindan_id = state['id'] + today = f"@{datetime.date.today().strftime('%Y%m%d')}@" + # 加入日期使每天的结果不一样 + _input_name = f'{input_name}{today}' + result = await ShindanMaker(maker_id=shindan_id).get_result(input_name=_input_name) + if result.error: + logger.error(f'User: {event.user_id} 获取 ShindanMaker 占卜结果失败, {result.info}') + await shindan_maker_default.finish('获取ShindanMaker占卜结果失败了QAQ, 请稍后再试') + + result_text = result.result.replace(today, '') + msg = f'{shindan_name}@{input_name}\n{"="*16}\n{result_text}' + await shindan_maker_default.finish(msg) + + +shindan_pattern = r'^今天的?(.+?)是什么(样的)?(.+?)[??]?$' +shindan_maker_today_custom = shindan_maker.on_regex( + shindan_pattern, + rule=OmegaRules.has_group_command_permission() & OmegaRules.has_level_or_node(30, 'shindan_maker.basic') +) + + +@shindan_maker_today_custom.handle() +async def handle_shojo(bot: Bot, event: GroupMessageEvent, state: T_State): + # 固定的id + shindan_custon_id: Dict[str, int] = { + '少女': 162207, + '魔法少女': 828741, + '偶像': 828727, + '做的': 761425, + '干员': 959146, + '小动物': 828905, + '猫': 28998, + '主角': 828977, + '宝石': 890951, + '花': 829525 + } + + args = str(event.get_plaintext()).strip() + input_name, shindan_name = re.findall(shindan_pattern, args)[0][0], re.findall(shindan_pattern, args)[0][2] + shindan_id = shindan_custon_id.get(shindan_name, None) + if not shindan_id: + logger.info(f'User: {event.user_id} 获取 ShindanMaker 占卜结果被中止, 没有对应的预置占卜, {shindan_name} not found') + await shindan_maker_today_custom.finish( + f'没有你想问的东西哦, 或者你是想知道, 今天的XX是什么{"/".join(shindan_custon_id.keys())}吗?') + + today = f"@{datetime.date.today().strftime('%Y%m%d')}@" + # 加入日期使每天的结果不一样 + _input_name = f'{input_name}{today}' + result = await ShindanMaker(maker_id=shindan_id).get_result(input_name=_input_name) + if result.error: + logger.error(f'User: {event.user_id} 获取 ShindanMaker 占卜结果失败, {result.info}') + await shindan_maker_today_custom.finish('获取ShindanMaker占卜结果失败了QAQ, 请稍后再试') + + result_text = result.result.replace(today, '') + await shindan_maker_today_custom.finish(result_text) diff --git a/omega_miya/plugins/shindan_maker/data_source.py b/omega_miya/plugins/shindan_maker/data_source.py new file mode 100644 index 00000000..6fd7da06 --- /dev/null +++ b/omega_miya/plugins/shindan_maker/data_source.py @@ -0,0 +1,111 @@ +""" +@Author : Ailitonia +@Date : 2021/06/28 21:42 +@FileName : data_source.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + + +from nonebot import logger +from bs4 import BeautifulSoup +from omega_miya.utils.Omega_Base import Result +from omega_miya.utils.Omega_plugin_utils import HttpFetcher + + +class ShindanMaker(object): + ROOT_URL = 'https://cn.shindanmaker.com' + SEARCH_URL = f'{ROOT_URL}/list/search' + + HEADERS = {'accept': '*/*', + 'accept-encoding': 'gzip, deflate', + 'accept-language': 'zh-CN,zh;q=0.9', + 'cache-control': 'max-age=0', + 'dnt': '1', + 'sec-ch-ua': '"Google Chrome";v="89", "Chromium";v="89", ";Not A Brand";v="99"', + 'sec-ch-ua-mobile': '?0', + 'sec-fetch-dest': 'document', + 'sec-fetch-mode': 'navigate', + 'sec-fetch-site': 'none', + 'sec-fetch-user': '?1', + 'sec-gpc': '1', + 'upgrade-insecure-requests': '1', + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' + 'Chrome/91.0.4472.124 Safari/537.36'} + + def __init__(self, maker_id: int): + self.maker_id = maker_id + + @classmethod + async def search(cls, keyword: str) -> Result.DictListResult: + fetcher = HttpFetcher(timeout=10, flag='shindanmaker_search', headers=cls.HEADERS) + html_result = await fetcher.get_text(url=cls.SEARCH_URL, params={'q': keyword}) + if html_result.error: + return Result.DictListResult(error=True, info=f'Fetch search result failed, {html_result.info}', result=[]) + + try: + _bs = BeautifulSoup(html_result.result, 'lxml') + all_result = _bs.find_all(name='a', attrs={'class': 'shindanLink'}) + result = [] + for item in all_result: + _url = item.attrs['href'] + _id = int(str(_url).replace(f'{cls.ROOT_URL}/', '')) + _name = item.get_text(strip=True) + result.append({ + 'id': _id, + 'url': _url, + 'name': _name + }) + return Result.DictListResult(error=False, info='Success', result=result) + except Exception as e: + logger.error(f'ShindanMaker | Parse search result failed, error: {repr(e)}') + return Result.DictListResult(error=True, info=f'Parse search result failed', result=[]) + + async def get_result(self, input_name: str) -> Result.TextResult: + fetcher = HttpFetcher(timeout=10, flag='shindanmaker_get_token', headers=self.HEADERS) + url = f'{self.ROOT_URL}/{self.maker_id}' + + html_result = await fetcher.get_text(url=url) + if html_result.error: + return Result.TextResult(error=True, info=f'Fetch shindan_maker page failed, {html_result.info}', result='') + elif html_result.status == 404: + return Result.TextResult(error=True, info=f'Shindan_maker page not found, 404 error', result='') + + try: + _bs = BeautifulSoup(html_result.result, 'lxml') + _input_form = _bs.find(name='form', attrs={'id': 'shindanForm', 'method': 'POST'}) + _token = _input_form.find(name='input', attrs={'type': 'hidden', 'name': '_token'}).attrs['value'] + except Exception as e: + logger.error(f'ShindanMaker | Parse page token failed, error: {repr(e)}') + return Result.TextResult(error=True, info=f'Parse page token failed', result='') + + _header = self.HEADERS.update({ + 'content-type': 'application/x-www-form-urlencoded', + 'origin': self.ROOT_URL, + 'referer': f'{self.ROOT_URL}/{self.maker_id}', + 'sec-fetch-site': 'same-origin', + 'upgrade-insecure-requests': '1' + }) + fetcher = HttpFetcher(timeout=10, flag='shindanmaker_get_result', cookies=html_result.cookies, headers=_header) + data = fetcher.FormData() + data.add_field(name='_token', value=_token) + data.add_field(name='shindanName', value=input_name) + data.add_field(name='hiddenName', value='无名的X') + maker_result = await fetcher.post_text(url=url, data=data) + + if maker_result.error: + return Result.TextResult( + error=True, info=f'Fetch shindan_maker result failed, {maker_result.info}', result='') + + try: + _bs = BeautifulSoup(maker_result.result, 'lxml') + _result = _bs.find(name='span', attrs={'id': 'shindanResult'}) + for line_break in _result.findAll(name='br'): + line_break.replaceWith('\n') + _result = _result.get_text() + return Result.TextResult(error=False, info='Success', result=_result) + except Exception as e: + logger.error(f'ShindanMaker | Parse result page failed, error: {repr(e)}') + return Result.TextResult(error=True, info=f'Parse result page failed', result='') diff --git a/omega_miya/plugins/sticker_maker/__init__.py b/omega_miya/plugins/sticker_maker/__init__.py index d856b3aa..d78f1887 100644 --- a/omega_miya/plugins/sticker_maker/__init__.py +++ b/omega_miya/plugins/sticker_maker/__init__.py @@ -1,4 +1,5 @@ import re +import pathlib from nonebot import on_command, export, logger from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot @@ -18,12 +19,21 @@ **Permission** Friend Private Command & Lv.10 +or AuthNode + +**AuthNode** +basic **Usage** /表情包 [模板名]''' +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'basic' +] + # Init plugin export -init_export(export(), __plugin_name__, __plugin_usage__) +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) sticker = on_command( @@ -33,7 +43,8 @@ state=init_permission_state( name='sticker', command=True, - level=10), + level=10, + auth_node='basic'), permission=GROUP | PRIVATE_FRIEND, priority=10, block=True) @@ -68,10 +79,13 @@ async def handle_sticker(bot: Bot, event: MessageEvent, state: T_State): '默认': {'name': 'default', 'type': 'default', 'text_part': 1, 'help_msg': '该模板不支持gif'}, '白底': {'name': 'whitebg', 'type': 'default', 'text_part': 1, 'help_msg': '该模板不支持gif'}, '黑框': {'name': 'blackbg', 'type': 'default', 'text_part': 1, 'help_msg': '该模板不支持gif'}, + '黑白': {'name': 'decolorize', 'type': 'default', 'text_part': 0, 'help_msg': '该模板不支持gif'}, + '生草日语': {'name': 'grassja', 'type': 'default', 'text_part': 1, 'help_msg': '该模板不支持gif'}, '小天使': {'name': 'littleangel', 'type': 'default', 'text_part': 1, 'help_msg': '该模板不支持gif'}, '有内鬼': {'name': 'traitor', 'type': 'static', 'text_part': 1, 'help_msg': '该模板字数限制100(x)'}, '记仇': {'name': 'jichou', 'type': 'static', 'text_part': 1, 'help_msg': '该模板字数限制100(x)'}, - 'ph': {'name': 'phlogo', 'type': 'static', 'text_part': 1, 'help_msg': '两部分文字中间请用空格隔开'} + 'ph': {'name': 'phlogo', 'type': 'static', 'text_part': 1, 'help_msg': '两部分文字中间请用空格隔开'}, + 'petpet': {'name': 'petpet', 'type': 'gif', 'text_part': 0, 'help_msg': '最好使用长宽比接近正方形的图片'} } get_sticker_temp = state['temp'] @@ -86,23 +100,28 @@ async def handle_sticker(bot: Bot, event: MessageEvent, state: T_State): state['temp_help_msg'] = sticker_temp[get_sticker_temp]['help_msg'] # 判断该模板表情图片来源 - if state['temp_type'] in ['static', 'gif']: + if state['temp_type'] in ['static']: state['image_url'] = None + # 判断是否需要文字 + if state['temp_text_part'] == 0: + state['sticker_text'] = '' + @sticker.got('image_url', prompt='请发送你想要制作的表情包的图片:') async def handle_img(bot: Bot, event: MessageEvent, state: T_State): image_url = state['image_url'] - if state['temp_type'] not in ['static', 'gif']: - if not re.match(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=.+])', image_url): - await sticker.reject('你发送的似乎不是图片呢, 请重新发送, 取消命令请发送【取消】:') - + if state['temp_type'] not in ['static']: # 提取图片url - image_url = re.sub(r'^(\[CQ:image,file=[abcdef\d]{32}\.image,url=)', '', image_url) - image_url = re.sub(r'(])$', '', image_url) - + image_url = None + for msg_seg in event.message: + if msg_seg.type == 'image': + image_url = msg_seg.data.get('url') + break + # 没有提取到图片url + if not image_url: + await sticker.reject('你发送的似乎不是图片呢, 请重新发送, 取消命令请发送【取消】:') state['image_url'] = image_url - state['sticker_text'] = None @sticker.got('sticker_text', prompt='请输入你想要制作的表情包的文字:') @@ -120,17 +139,21 @@ async def handle_sticker_text(bot: Bot, event: MessageEvent, state: T_State): f'\n\n注意: 请用【#】号分割文本不同段落,不同模板适用的文字字数及段落数有所区别' else: text_msg = f'请输入你想要制作的表情包的文字: \n注意: 不同模板适用的文字字数有所区别' - if not sticker_text: - await sticker.reject(text_msg) - # 过滤CQ码 - if re.match(r'\[CQ:', sticker_text, re.I): - await sticker.finish('含非法字符QAQ') + if sticker_temp_text_part == 0: + pass + else: + if not sticker_text: + await sticker.reject(text_msg) + + # 过滤CQ码 + if re.match(r'\[CQ:', sticker_text, re.I): + await sticker.finish('含非法字符QAQ') - if len(sticker_text.strip().split('#')) != sticker_temp_text_part: - eg_msg = r'我就是饿死#死外边 从这里跳下去#也不会吃你们一点东西#真香' - await sticker.finish(f"表情制作失败QAQ, 文本分段数错误\n" - f"当前模板文本分段数:【{sticker_temp_text_part}】\n\n示例: \n{eg_msg}") + if len(sticker_text.strip().split('#')) != sticker_temp_text_part: + eg_msg = r'我就是饿死#死外边 从这里跳下去#也不会吃你们一点东西#真香' + await sticker.finish(f"表情制作失败QAQ, 文本分段数错误\n" + f"当前模板文本分段数:【{sticker_temp_text_part}】\n\n示例: \n{eg_msg}") sticker_image_url = state['image_url'] sticker_temp_name = state['temp_name'] @@ -149,7 +172,8 @@ async def handle_sticker_text(bot: Bot, event: MessageEvent, state: T_State): # sticker_seg = MessageSegment.image(sticker_b64) # 直接用文件构造消息段 - sticker_seg = MessageSegment.image(f'file:///{sticker_path}') + file_url = pathlib.Path(sticker_path).as_uri() + sticker_seg = MessageSegment.image(file=file_url) # 发送图片 await sticker.send(sticker_seg) diff --git a/omega_miya/plugins/sticker_maker/utils/__init__.py b/omega_miya/plugins/sticker_maker/utils/__init__.py index 113b5938..4f8d1b1b 100644 --- a/omega_miya/plugins/sticker_maker/utils/__init__.py +++ b/omega_miya/plugins/sticker_maker/utils/__init__.py @@ -1,4 +1,6 @@ import os +import aiofiles +from typing import Optional from io import BytesIO from datetime import datetime from omega_miya.utils.Omega_plugin_utils import HttpFetcher @@ -6,6 +8,7 @@ from PIL import Image from .default_render import * from .static_render import * +from .gif_render import * global_config = get_driver().config TMP_PATH = global_config.tmp_path_ @@ -14,16 +17,47 @@ 'Chrome/89.0.4389.114 Safari/537.36'} -async def sticker_maker_main(url: str, temp: str, text: str, sticker_temp_type: str): +async def sticker_maker_main(url: str, temp: str, text: str, sticker_temp_type: str) -> Optional[str]: # 定义表情包处理函数 stick_maker = { 'default': stick_maker_temp_default, 'whitebg': stick_maker_temp_whitebg, 'blackbg': stick_maker_temp_blackbg, + 'decolorize': stick_maker_temp_decolorize, + 'grassja': stick_maker_temp_grass_ja, 'littleangel': stick_maker_temp_littleangel, 'traitor': stick_maker_static_traitor, 'jichou': stick_maker_static_jichou, - 'phlogo': stick_maker_static_phlogo + 'phlogo': stick_maker_static_phlogo, + 'petpet': stick_maker_temp_petpet + } + + # 定义表情包模板使用字体 + sticker_default_font = { + 'default': 'msyhbd.ttc', + 'whitebg': 'msyhbd.ttc', + 'blackbg': 'msyhbd.ttc', + 'decolorize': 'msyhbd.ttc', + 'grassja': 'fzzxhk.ttf', + 'littleangel': 'msyhbd.ttc', + 'traitor': 'pixel.ttf', + 'jichou': 'SourceHanSans_Regular.otf', + 'phlogo': 'SourceHanSans_Heavy.otf', + 'petpet': 'SourceHanSans_Regular.otf' + } + + # 定义表情包模板默认宽度 + sticker_default_width = { + 'default': 512, + 'whitebg': 512, + 'blackbg': 512, + 'decolorize': 512, + 'grassja': 800, + 'littleangel': 512, + 'traitor': 512, + 'jichou': 512, + 'phlogo': 512, + 'petpet': 512 } # 检查生成表情包路径 @@ -34,6 +68,16 @@ async def sticker_maker_main(url: str, temp: str, text: str, sticker_temp_type: # 插件路径 plugin_src_path = os.path.abspath(os.path.dirname(__file__)) + # 字体路径 + font_path = os.path.join(plugin_src_path, 'fonts', sticker_default_font.get(temp)) + # 检查预置字体 + if not os.path.exists(font_path): + logger.error(f"Stick_maker: 模板预置文件错误, 字体{sticker_default_font.get(temp)}不存在") + return None + + # 表情包宽度 + sticker_width = sticker_default_width.get(temp, 512) + # 默认模式 if sticker_temp_type == 'default': fetcher = HttpFetcher(timeout=10, flag='sticker_maker_main_default', headers=HEADERS) @@ -45,23 +89,21 @@ async def sticker_maker_main(url: str, temp: str, text: str, sticker_temp_type: image_bytes_f = BytesIO() image_bytes_f.write(image_result.result) - # 字体路径 - font_path = os.path.join(plugin_src_path, 'fonts', 'msyhbd.ttc') # 生成表情包路径 sticker_path = os.path.abspath( os.path.join(sticker_folder_path, f"{temp}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.jpg")) # 调整图片大小(宽度512像素) make_image = Image.open(image_bytes_f) - image_resize_width = 512 - image_resize_height = 512 * make_image.height // make_image.width + image_resize_width = sticker_width + image_resize_height = sticker_width * make_image.height // make_image.width make_image = make_image.resize((image_resize_width, image_resize_height)) # 调用模板处理图片 - make_image = stick_maker[temp](text=text, image_file=make_image, font_path=font_path, - image_wight=image_resize_width, image_height=image_resize_height) + made_image = await stick_maker[temp](text=text, image_file=make_image, font_path=font_path, + image_wight=image_resize_width, image_height=image_resize_height) # 输出图片 - make_image.save(sticker_path, 'JPEG') + made_image.save(sticker_path, 'JPEG') image_bytes_f.close() return sticker_path @@ -77,17 +119,6 @@ async def sticker_maker_main(url: str, temp: str, text: str, sticker_temp_type: return None bg_image_path = os.path.join(static_temp_path, 'default_bg.png') - # 检查预置字体 - if os.path.exists(os.path.join(static_temp_path, 'default_font.ttc')): - font_path = os.path.join(static_temp_path, 'default_font.ttc') - elif os.path.exists(os.path.join(static_temp_path, 'default_font.ttf')): - font_path = os.path.join(static_temp_path, 'default_font.ttf') - elif os.path.exists(os.path.join(static_temp_path, 'default_font.otf')): - font_path = os.path.join(static_temp_path, 'default_font.otf') - else: - logger.error(f'Stick_maker: 模板预置文件错误, 默认字体应为default_font.ttc、default_font.ttf或default_font.otf') - return None - # 生成表情包路径 sticker_path = os.path.abspath( os.path.join(sticker_folder_path, f"{temp}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.jpg")) @@ -97,13 +128,48 @@ async def sticker_maker_main(url: str, temp: str, text: str, sticker_temp_type: (image_resize_width, image_resize_height) = make_image.size # 调用模板处理图片 - make_image = stick_maker[temp](text=text, image_file=make_image, font_path=font_path, - image_wight=image_resize_width, image_height=image_resize_height) + make_image = await stick_maker[temp](text=text, image_file=make_image, font_path=font_path, + image_wight=image_resize_width, image_height=image_resize_height) # 输出图片 make_image.save(sticker_path, 'JPEG') return sticker_path + # gif模板模式 + elif sticker_temp_type == 'gif': + # 模板路径 + gif_temp_path = os.path.abspath(os.path.join(plugin_src_path, 'gif_template', temp)) + + fetcher = HttpFetcher(timeout=10, flag='sticker_maker_main_default', headers=HEADERS) + image_result = await fetcher.get_bytes(url=url) + if image_result.error: + logger.error(f'Stick_maker download image failed: {image_result.info}') + return None + + # 生成表情包路径 + sticker_path = os.path.abspath( + os.path.join(sticker_folder_path, f"{temp}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.gif")) + + with BytesIO() as image_bytes_f: + image_bytes_f.write(image_result.result) + + # 调整图片大小 + make_image = Image.open(image_bytes_f) + image_resize_width = sticker_width + image_resize_height = image_resize_width * make_image.height // make_image.width + make_image = make_image.resize((image_resize_width, image_resize_height)) + + # 调用模板处理图片 + made_image = await stick_maker[temp](text=text, image_file=make_image, font_path=font_path, + image_wight=image_resize_width, image_height=image_resize_height, + temp_path=gif_temp_path) + + if not made_image: + return None + else: + async with aiofiles.open(sticker_path, 'wb+') as af: + await af.write(made_image) + return sticker_path else: return None diff --git a/omega_miya/plugins/sticker_maker/utils/default_render.py b/omega_miya/plugins/sticker_maker/utils/default_render.py index 4a6641bd..54804cd8 100644 --- a/omega_miya/plugins/sticker_maker/utils/default_render.py +++ b/omega_miya/plugins/sticker_maker/utils/default_render.py @@ -1,152 +1,249 @@ -from PIL import Image, ImageDraw, ImageFont - - -def stick_maker_temp_default(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 处理图片 - draw = ImageDraw.Draw(image_file) - font_size = 72 - font = ImageFont.truetype(font_path, font_size) - text_w, text_h = font.getsize_multiline(text) - while text_w >= image_wight: - font_size = font_size * 3 // 4 +import asyncio +from PIL import Image, ImageDraw, ImageFont, ImageEnhance +from omega_miya.utils.tencent_cloud_api import TencentTMT + + +async def stick_maker_temp_default( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 默认加字表情包模板 + """ + def __handle() -> Image.Image: + # 处理图片 + draw = ImageDraw.Draw(image_file) + font_size = 72 font = ImageFont.truetype(font_path, font_size) text_w, text_h = font.getsize_multiline(text) - # 计算居中文字位置 - text_coordinate = (((image_wight - text_w) // 2), 9 * (image_height - text_h) // 10) - # 为文字设置黑边 - text_b_resize = 4 - if font_size >= 72: + while text_w >= image_wight: + font_size = font_size * 3 // 4 + font = ImageFont.truetype(font_path, font_size) + text_w, text_h = font.getsize_multiline(text) + # 计算居中文字位置 + text_coordinate = (((image_wight - text_w) // 2), 9 * (image_height - text_h) // 10) + # 为文字设置黑边 text_b_resize = 4 - elif font_size >= 36: - text_b_resize = 3 - elif font_size >= 24: - text_b_resize = 2 - elif font_size < 12: - text_b_resize = 1 - text_coordinate_b1 = (text_coordinate[0] + text_b_resize, text_coordinate[1]) - text_coordinate_b2 = (text_coordinate[0] - text_b_resize, text_coordinate[1]) - text_coordinate_b3 = (text_coordinate[0], text_coordinate[1] + text_b_resize) - text_coordinate_b4 = (text_coordinate[0], text_coordinate[1] - text_b_resize) - draw.multiline_text(text_coordinate_b1, text, font=font, fill=(0, 0, 0)) - draw.multiline_text(text_coordinate_b2, text, font=font, fill=(0, 0, 0)) - draw.multiline_text(text_coordinate_b3, text, font=font, fill=(0, 0, 0)) - draw.multiline_text(text_coordinate_b4, text, font=font, fill=(0, 0, 0)) - # 白字要后画,后画的在上层,不然就是黑滋在上面挡住了 - draw.multiline_text(text_coordinate, text, font=font, fill=(255, 255, 255)) - return image_file - - -def stick_maker_temp_littleangel(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 处理图片 - background_w = image_wight + 100 - background_h = image_height + 230 - background = Image.new(mode="RGB", size=(background_w, background_h), color=(255, 255, 255)) - # 处理粘贴位置 上留100像素,下留130像素 - image_coordinate = (((background_w - image_wight) // 2), 100) - background.paste(image_file, image_coordinate) - draw = ImageDraw.Draw(background) - - font_down_1 = ImageFont.truetype(font_path, 48) - text_down_1 = r'非常可爱!简直就是小天使' - text_down_1_w, text_down_1_h = font_down_1.getsize(text_down_1) - text_down_1_coordinate = (((background_w - text_down_1_w) // 2), background_h - 120) - draw.text(text_down_1_coordinate, text_down_1, font=font_down_1, fill=(0, 0, 0)) - - font_down_2 = ImageFont.truetype(font_path, 26) - text_down_2 = r'她没失踪也没怎么样 我只是觉得你们都该看一下' - text_down_2_w, text_down_2_h = font_down_2.getsize(text_down_2) - text_down_2_coordinate = (((background_w - text_down_2_w) // 2), background_h - 60) - draw.text(text_down_2_coordinate, text_down_2, font=font_down_2, fill=(0, 0, 0)) - - font_size_up = 72 - font_up = ImageFont.truetype(font_path, font_size_up) - text_up = f'请问你们看到{text}了吗?' - text_up_w, text_up_h = font_up.getsize(text_up) - while text_up_w >= background_w: - font_size_up = font_size_up * 5 // 6 + if font_size >= 72: + text_b_resize = 4 + elif font_size >= 36: + text_b_resize = 3 + elif font_size >= 24: + text_b_resize = 2 + elif font_size < 12: + text_b_resize = 1 + text_coordinate_b1 = (text_coordinate[0] + text_b_resize, text_coordinate[1]) + text_coordinate_b2 = (text_coordinate[0] - text_b_resize, text_coordinate[1]) + text_coordinate_b3 = (text_coordinate[0], text_coordinate[1] + text_b_resize) + text_coordinate_b4 = (text_coordinate[0], text_coordinate[1] - text_b_resize) + draw.multiline_text(text_coordinate_b1, text, font=font, fill=(0, 0, 0)) + draw.multiline_text(text_coordinate_b2, text, font=font, fill=(0, 0, 0)) + draw.multiline_text(text_coordinate_b3, text, font=font, fill=(0, 0, 0)) + draw.multiline_text(text_coordinate_b4, text, font=font, fill=(0, 0, 0)) + # 白字要后画,后画的在上层,不然就是黑滋在上面挡住了 + draw.multiline_text(text_coordinate, text, font=font, fill=(255, 255, 255)) + return image_file + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_temp_littleangel( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 小天使表情包模板 + """ + def __handle() -> Image.Image: + # 处理图片 + background_w = image_wight + 100 + background_h = image_height + 230 + background = Image.new(mode="RGB", size=(background_w, background_h), color=(255, 255, 255)) + # 处理粘贴位置 上留100像素,下留130像素 + image_coordinate = (((background_w - image_wight) // 2), 100) + background.paste(image_file, image_coordinate) + draw = ImageDraw.Draw(background) + + font_down_1 = ImageFont.truetype(font_path, 48) + text_down_1 = r'非常可爱!简直就是小天使' + text_down_1_w, text_down_1_h = font_down_1.getsize(text_down_1) + text_down_1_coordinate = (((background_w - text_down_1_w) // 2), background_h - 120) + draw.text(text_down_1_coordinate, text_down_1, font=font_down_1, fill=(0, 0, 0)) + + font_down_2 = ImageFont.truetype(font_path, 26) + text_down_2 = r'她没失踪也没怎么样 我只是觉得你们都该看一下' + text_down_2_w, text_down_2_h = font_down_2.getsize(text_down_2) + text_down_2_coordinate = (((background_w - text_down_2_w) // 2), background_h - 60) + draw.text(text_down_2_coordinate, text_down_2, font=font_down_2, fill=(0, 0, 0)) + + font_size_up = 72 font_up = ImageFont.truetype(font_path, font_size_up) + text_up = f'请问你们看到{text}了吗?' text_up_w, text_up_h = font_up.getsize(text_up) - # 计算居中文字位置 - text_up_coordinate = (((background_w - text_up_w) // 2), 25) - draw.text(text_up_coordinate, text_up, font=font_up, fill=(0, 0, 0)) - return background - - -def stick_maker_temp_whitebg(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 处理文本 - if image_wight > image_height: - font_size = 72 - else: - font_size = 84 - font = ImageFont.truetype(font_path, font_size) - text_w, text_h = font.getsize_multiline(text) - while text_w >= (image_wight * 8 // 9): - font_size = font_size * 7 // 8 + while text_up_w >= background_w: + font_size_up = font_size_up * 5 // 6 + font_up = ImageFont.truetype(font_path, font_size_up) + text_up_w, text_up_h = font_up.getsize(text_up) + # 计算居中文字位置 + text_up_coordinate = (((background_w - text_up_w) // 2), 25) + draw.text(text_up_coordinate, text_up, font=font_up, fill=(0, 0, 0)) + return background + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_temp_whitebg( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 白底加字表情包模板 + """ + def __handle() -> Image.Image: + # 处理文本 + if image_wight > image_height: + font_size = 72 + else: + font_size = 84 font = ImageFont.truetype(font_path, font_size) text_w, text_h = font.getsize_multiline(text) - - # 处理图片 - background_w = image_wight - background_h = image_height + round(text_h * 1.5) - background = Image.new(mode="RGB", size=(background_w, background_h), color=(255, 255, 255)) - - # 处理粘贴位置 顶头 - image_coordinate = (0, 0) - background.paste(image_file, image_coordinate) - - draw = ImageDraw.Draw(background) - - # 计算居中文字位置 - text_coordinate = (((background_w - text_w) // 2), image_height + round(text_h / 100) * round(text_h * 0.1)) - - draw.multiline_text(text_coordinate, text, font=font, fill=(0, 0, 0)) - - return background - - -def stick_maker_temp_blackbg(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 处理文本 - if image_wight > image_height: - font_size = 96 - else: - font_size = 108 - font = ImageFont.truetype(font_path, font_size) - text_w, text_h = font.getsize_multiline(text) - while text_w >= (image_wight * 9 // 10): - font_size = font_size * 8 // 9 + while text_w >= (image_wight * 8 // 9): + font_size = font_size * 7 // 8 + font = ImageFont.truetype(font_path, font_size) + text_w, text_h = font.getsize_multiline(text) + + # 处理图片 + background_w = image_wight + background_h = image_height + round(text_h * 1.5) + background = Image.new(mode="RGB", size=(background_w, background_h), color=(255, 255, 255)) + + # 处理粘贴位置 顶头 + image_coordinate = (0, 0) + background.paste(image_file, image_coordinate) + + draw = ImageDraw.Draw(background) + # 计算居中文字位置 + text_coordinate = (((background_w - text_w) // 2), image_height + round(text_h / 100) * round(text_h * 0.1)) + draw.multiline_text(text_coordinate, text, font=font, fill=(0, 0, 0)) + return background + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_temp_blackbg( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 黑边加底字表情包模板 + """ + def __handle() -> Image.Image: + # 处理文本 + if image_wight > image_height: + font_size = 96 + else: + font_size = 108 font = ImageFont.truetype(font_path, font_size) text_w, text_h = font.getsize_multiline(text) + while text_w >= (image_wight * 9 // 10): + font_size = font_size * 8 // 9 + font = ImageFont.truetype(font_path, font_size) + text_w, text_h = font.getsize_multiline(text) + + # 处理图片 + background_w = image_wight + 150 + background_h = image_height + 115 + round(text_h * 1.5) + background = Image.new(mode="RGB", size=(background_w, background_h), color=(0, 0, 0)) + layer_1 = Image.new(mode="RGB", size=(image_wight + 12, image_height + 12), color=(255, 255, 255)) + layer_2 = Image.new(mode="RGB", size=(image_wight + 10, image_height + 10), color=(0, 0, 0)) + layer_3 = Image.new(mode="RGB", size=(image_wight + 6, image_height + 6), color=(255, 255, 255)) + layer_4 = Image.new(mode="RGB", size=(image_wight + 4, image_height + 4), color=(0, 0, 0)) + + # 处理粘贴位置 留出黑边距离 + background.paste(layer_1, (70, 70)) + background.paste(layer_2, (71, 71)) + background.paste(layer_3, (73, 73)) + background.paste(layer_4, (74, 74)) + background.paste(image_file, (76, 76)) + + draw = ImageDraw.Draw(background) + + # 计算居中文字位置 + text_coordinate = (((background_w - text_w) // 2), + image_height + 110 - round(text_h / 9) + round(text_h / 100) * round(text_h * 0.1)) + draw.multiline_text(text_coordinate, text, font=font, fill=(255, 255, 255)) + return background + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_temp_decolorize( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 去色加字表情包模板 + """ + def __handle() -> Image.Image: + enhancer = ImageEnhance.Color(image_file) + made_image = enhancer.enhance(0) + return made_image + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_temp_grass_ja( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 生草日语表情包模板 + """ + # 处理日语翻译 + text_zh = text.replace('\n', ' ') + text_trans_result = await TencentTMT().translate(source_text=text, target='ja') + text_ja = str(text_trans_result.result.get('targettext', '翻訳に失敗しました!')).replace('\n', ' ') + text_ = f'{text_zh}\n{text_ja}' + + # 处理黑白 + image_file_ = await stick_maker_temp_decolorize(text, image_file, font_path, image_wight, image_height) + + def __handle() -> Image.Image: + # 处理文本 + if image_wight > image_height: + font_size = 48 + else: + font_size = 60 + font = ImageFont.truetype(font_path, font_size) + text_w, text_h = font.getsize_multiline(text_) + while text_w >= (image_wight * 9 // 10): + font_size = font_size * 8 // 9 + font = ImageFont.truetype(font_path, font_size) + text_w, text_h = font.getsize_multiline(text_) - # 处理图片 - background_w = image_wight + 150 - background_h = image_height + 115 + round(text_h * 1.5) - background = Image.new(mode="RGB", size=(background_w, background_h), color=(0, 0, 0)) - layer_1 = Image.new(mode="RGB", size=(image_wight + 12, image_height + 12), color=(255, 255, 255)) - layer_2 = Image.new(mode="RGB", size=(image_wight + 10, image_height + 10), color=(0, 0, 0)) - layer_3 = Image.new(mode="RGB", size=(image_wight + 6, image_height + 6), color=(255, 255, 255)) - layer_4 = Image.new(mode="RGB", size=(image_wight + 4, image_height + 4), color=(0, 0, 0)) - - # 处理粘贴位置 留出黑边距离 - background.paste(layer_1, (70, 70)) - background.paste(layer_2, (71, 71)) - background.paste(layer_3, (73, 73)) - background.paste(layer_4, (74, 74)) - background.paste(image_file, (76, 76)) + # 处理图片 + background_w = image_wight + background_h = image_height + round(text_h * 1.5) + background = Image.new(mode="RGB", size=(background_w, background_h), color=(0, 0, 0)) - draw = ImageDraw.Draw(background) + # 处理粘贴位置 留出黑边距离 + background.paste(image_file_, (0, 0)) - # 计算居中文字位置 - text_coordinate = (((background_w - text_w) // 2), - image_height + 110 - round(text_h / 9) + round(text_h / 100) * round(text_h * 0.1)) + draw = ImageDraw.Draw(background) - draw.multiline_text(text_coordinate, text, font=font, fill=(255, 255, 255)) + # 计算居中文字位置 + text_coordinate = (((background_w - text_w) // 2), image_height + round(text_h * 0.2)) + draw.multiline_text(text_coordinate, text_, font=font, align='center', fill=(255, 255, 255)) + return background - return background + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result __all__ = [ 'stick_maker_temp_whitebg', 'stick_maker_temp_blackbg', 'stick_maker_temp_default', - 'stick_maker_temp_littleangel' + 'stick_maker_temp_littleangel', + 'stick_maker_temp_decolorize', + 'stick_maker_temp_grass_ja' ] diff --git a/omega_miya/plugins/sticker_maker/utils/fonts/K_Gothic.ttf b/omega_miya/plugins/sticker_maker/utils/fonts/K_Gothic.ttf new file mode 100644 index 00000000..fa067bb0 Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/fonts/K_Gothic.ttf differ diff --git a/omega_miya/plugins/sticker_maker/utils/static/phlogo/default_font.otf b/omega_miya/plugins/sticker_maker/utils/fonts/SourceHanSans_Heavy.otf similarity index 100% rename from omega_miya/plugins/sticker_maker/utils/static/phlogo/default_font.otf rename to omega_miya/plugins/sticker_maker/utils/fonts/SourceHanSans_Heavy.otf diff --git a/omega_miya/plugins/sticker_maker/utils/static/jichou/default_font.ttc b/omega_miya/plugins/sticker_maker/utils/fonts/SourceHanSans_Regular.otf similarity index 56% rename from omega_miya/plugins/sticker_maker/utils/static/jichou/default_font.ttc rename to omega_miya/plugins/sticker_maker/utils/fonts/SourceHanSans_Regular.otf index 980ec434..88e3a354 100644 Binary files a/omega_miya/plugins/sticker_maker/utils/static/jichou/default_font.ttc and b/omega_miya/plugins/sticker_maker/utils/fonts/SourceHanSans_Regular.otf differ diff --git a/omega_miya/plugins/sticker_maker/utils/fonts/fzzxhk.ttf b/omega_miya/plugins/sticker_maker/utils/fonts/fzzxhk.ttf new file mode 100644 index 00000000..3d55a8dd Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/fonts/fzzxhk.ttf differ diff --git a/omega_miya/plugins/sticker_maker/utils/gif_render.py b/omega_miya/plugins/sticker_maker/utils/gif_render.py new file mode 100644 index 00000000..3414cf60 --- /dev/null +++ b/omega_miya/plugins/sticker_maker/utils/gif_render.py @@ -0,0 +1,84 @@ +""" +@Author : Ailitonia +@Date : 2021/06/27 17:31 +@FileName : gif_render.py +@Project : nonebot2_miya +@Description : gif表情包生成模板 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import os +import asyncio +import imageio +from io import BytesIO +from typing import Optional +from PIL import Image + + +async def stick_maker_temp_petpet( + text: str, + image_file: Image.Image, + font_path: str, + image_wight: int, + image_height: int, + temp_path: str) -> Optional[bytes]: + """ + petpet 表情包模板 + """ + def __handle() -> Optional[bytes]: + bg0 = Image.new(mode="RGBA", size=(112, 112), color=(255, 255, 255)) + bg1 = Image.new(mode="RGBA", size=(112, 112), color=(255, 255, 255)) + bg2 = Image.new(mode="RGBA", size=(112, 112), color=(255, 255, 255)) + bg3 = Image.new(mode="RGBA", size=(112, 112), color=(255, 255, 255)) + bg4 = Image.new(mode="RGBA", size=(112, 112), color=(255, 255, 255)) + tp0 = Image.open(os.path.join(temp_path, 'template_p0.png')) + tp1 = Image.open(os.path.join(temp_path, 'template_p1.png')) + tp2 = Image.open(os.path.join(temp_path, 'template_p2.png')) + tp3 = Image.open(os.path.join(temp_path, 'template_p3.png')) + tp4 = Image.open(os.path.join(temp_path, 'template_p4.png')) + bg0.paste(image_file.resize((95, 95)), (12, 15)) + bg1.paste(image_file.resize((97, 80)), (11, 30)) + bg2.paste(image_file.resize((99, 70)), (10, 40)) + bg3.paste(image_file.resize((97, 75)), (11, 35)) + bg4.paste(image_file.resize((96, 90)), (11, 20)) + bg0.paste(tp0, (0, 0), mask=tp0) + bg1.paste(tp1, (0, 0), mask=tp1) + bg2.paste(tp2, (0, 0), mask=tp2) + bg3.paste(tp3, (0, 0), mask=tp3) + bg4.paste(tp4, (0, 0), mask=tp4) + + frames_list = [] + with BytesIO() as bf0: + bg0.save(bf0, format='PNG') + img_bytes = bf0.getvalue() + frames_list.append(imageio.imread(img_bytes)) + with BytesIO() as bf1: + bg1.save(bf1, format='PNG') + img_bytes = bf1.getvalue() + frames_list.append(imageio.imread(img_bytes)) + with BytesIO() as bf2: + bg2.save(bf2, format='PNG') + img_bytes = bf2.getvalue() + frames_list.append(imageio.imread(img_bytes)) + with BytesIO() as bf3: + bg3.save(bf3, format='PNG') + img_bytes = bf3.getvalue() + frames_list.append(imageio.imread(img_bytes)) + with BytesIO() as bf4: + bg4.save(bf4, format='PNG') + img_bytes = bf4.getvalue() + frames_list.append(imageio.imread(img_bytes)) + + with BytesIO() as bf: + imageio.mimsave(bf, frames_list, 'GIF', duration=0.06) + return bf.getvalue() + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +__all__ = [ + 'stick_maker_temp_petpet' +] diff --git a/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template.gif b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template.gif new file mode 100644 index 00000000..75f3fbd3 Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template.gif differ diff --git a/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p0.png b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p0.png new file mode 100644 index 00000000..506c5b4c Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p0.png differ diff --git a/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p1.png b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p1.png new file mode 100644 index 00000000..d9028fef Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p1.png differ diff --git a/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p2.png b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p2.png new file mode 100644 index 00000000..c8b987ad Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p2.png differ diff --git a/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p3.png b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p3.png new file mode 100644 index 00000000..a8eb64f8 Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p3.png differ diff --git a/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p4.png b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p4.png new file mode 100644 index 00000000..917ce372 Binary files /dev/null and b/omega_miya/plugins/sticker_maker/utils/gif_template/petpet/template_p4.png differ diff --git a/omega_miya/plugins/sticker_maker/utils/static/traitor/default_font.ttf b/omega_miya/plugins/sticker_maker/utils/static/traitor/default_font.ttf deleted file mode 100644 index 7f6116c6..00000000 Binary files a/omega_miya/plugins/sticker_maker/utils/static/traitor/default_font.ttf and /dev/null differ diff --git a/omega_miya/plugins/sticker_maker/utils/static_render.py b/omega_miya/plugins/sticker_maker/utils/static_render.py index ccfee2a1..cfdb77a0 100644 --- a/omega_miya/plugins/sticker_maker/utils/static_render.py +++ b/omega_miya/plugins/sticker_maker/utils/static_render.py @@ -1,133 +1,159 @@ +import asyncio from PIL import Image, ImageDraw, ImageFont from datetime import date -def stick_maker_static_traitor(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 初始化背景图层 - background = Image.new(mode="RGB", size=(image_wight, image_height), color=(255, 255, 255)) - - # 处理文字层 字数部分 - text_num_img = Image.new(mode="RGBA", size=(image_wight, image_height), color=(0, 0, 0, 0)) - font_num_size = 48 - font_num = ImageFont.truetype(font_path, font_num_size) - ImageDraw.Draw(text_num_img).text(xy=(0, 0), text=f'{len(text)}/100', font=font_num, fill=(255, 255, 255)) - - # 处理文字层 主体部分 - text_main_img = Image.new(mode="RGBA", size=(image_wight, image_height), color=(0, 0, 0, 0)) - font_main_size = 54 - font_main = ImageFont.truetype(font_path, font_main_size) - # 按长度切分文本 - spl_num = 0 - spl_list = [] - for num in range(len(text)): - text_w = font_main.getsize_multiline(text[spl_num:num])[0] - if text_w >= 415: - spl_list.append(text[spl_num:num]) - spl_num = num - else: - spl_list.append(text[spl_num:]) - test_main_fin = '' - for item in spl_list: - test_main_fin += item + '\n' - ImageDraw.Draw(text_main_img).multiline_text(xy=(0, 0), text=test_main_fin, font=font_main, spacing=8, fill=(0, 0, 0)) - - # 处理文字部分旋转 - text_num_img = text_num_img.rotate(angle=-9, expand=True, resample=Image.BICUBIC, center=(0, 0)) - text_main_img = text_main_img.rotate(angle=-9.5, expand=True, resample=Image.BICUBIC, center=(0, 0)) - - # 向模板图片中置入文字图层 - background.paste(im=image_file, box=(0, 0)) - background.paste(im=text_num_img, box=(435, 140), mask=text_num_img) - background.paste(im=text_main_img, box=(130, 160), mask=text_main_img) - - return background - - -def stick_maker_static_jichou(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 处理文本主体 - text = f"今天是{date.today().strftime('%Y年%m月%d日')}{text}, 这个仇我先记下了" - font_main_size = 42 - font_main = ImageFont.truetype(font_path, font_main_size) - # 按长度切分文本 - spl_num = 0 - spl_list = [] - for num in range(len(text)): - text_w = font_main.getsize_multiline(text[spl_num:num])[0] - if text_w >= (image_wight * 7 // 8): - spl_list.append(text[spl_num:num]) - spl_num = num - else: - spl_list.append(text[spl_num:]) - text_main_fin = '\n'.join(spl_list) - - font = ImageFont.truetype(font_path, font_main_size) - text_w, text_h = font.getsize_multiline(text_main_fin) - - # 处理图片 - background_w = image_wight - background_h = image_height + text_h + 20 - background = Image.new(mode="RGB", size=(background_w, background_h), color=(255, 255, 255)) - - # 处理粘贴位置 顶头 - image_coordinate = (0, 0) - background.paste(image_file, image_coordinate) - - draw = ImageDraw.Draw(background) - - # 计算居中文字位置 - text_coordinate = (((background_w - text_w) // 2), image_height + 5) - - draw.multiline_text(text_coordinate, text_main_fin, font=font, fill=(0, 0, 0)) - - return background - - -def stick_maker_static_phlogo(text: str, image_file: bytes, font_path: str, image_wight: int, image_height: int): - # 处理文本主体 - test_sentences = text.strip().split(maxsplit=1) - white_text = test_sentences[0] - yellow_text = test_sentences[1] - - font_size = 640 - font = ImageFont.truetype(font_path, font_size) - text_w, text_h = font.getsize(text) - - y_text_w, y_text_h = font.getsize(yellow_text) - bg_y_text = Image.new(mode="RGB", size=(round(y_text_w * 1.1), round(text_h * 1.3)), color=(254, 154, 0)) - draw_y_text = ImageDraw.Draw(bg_y_text) - draw_y_text.text((round(y_text_w * 1.1) // 2, round(text_h * 1.3) // 2), yellow_text, anchor='mm', font=font, fill=(0, 0, 0)) - radii = 64 - # 画圆(用于分离4个角) - circle = Image.new('L', (radii * 2, radii * 2), 0) # 创建黑色方形 - draw_circle = ImageDraw.Draw(circle) - draw_circle.ellipse((0, 0, radii * 2, radii * 2), fill=255) # 黑色方形内切白色圆形 - # 原图转为带有alpha通道(表示透明程度) - bg_y_text = bg_y_text.convert("RGBA") - y_weight, y_height = bg_y_text.size - # 画4个角(将整圆分离为4个部分) - alpha = Image.new('L', bg_y_text.size, 255) # 与img同大小的白色矩形,L 表示黑白图 - alpha.paste(circle.crop((0, 0, radii, radii)), (0, 0)) # 左上角 - alpha.paste(circle.crop((radii, 0, radii * 2, radii)), (y_weight - radii, 0)) # 右上角 - alpha.paste(circle.crop((radii, radii, radii * 2, radii * 2)), (y_weight - radii, y_height - radii)) # 右下角 - alpha.paste(circle.crop((0, radii, radii, radii * 2)), (0, y_height - radii)) # 左下角 - bg_y_text.putalpha(alpha) # 白色区域透明可见,黑色区域不可见 - - w_text_w, w_text_h = font.getsize(white_text) - bg_w_text = Image.new(mode="RGB", size=(round(w_text_w * 1.05), round(text_h * 1.3)), color=(0, 0, 0)) - w_weight, w_height = bg_w_text.size - draw_w_text = ImageDraw.Draw(bg_w_text) - draw_w_text.text((round(w_text_w * 1.025) // 2, round(text_h * 1.3) // 2), white_text, anchor='mm', font=font, fill=(255, 255, 255)) - - text_bg = Image.new(mode="RGB", size=(w_weight + y_weight, y_height), color=(0, 0, 0)) - text_bg.paste(bg_w_text, (0, 0)) - text_bg.paste(bg_y_text, (round(w_text_w * 1.05), 0), mask=alpha) - t_weight, t_height = text_bg.size - - background = Image.new(mode="RGB", size=(round(t_weight * 1.2), round(t_height * 1.75)), color=(0, 0, 0)) - b_weight, b_height = background.size - background.paste(text_bg, ((b_weight - t_weight) // 2, (b_height - t_height) // 2)) - - return background +async def stick_maker_static_traitor( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 有内鬼表情包模板 + """ + def __handle() -> Image.Image: + # 初始化背景图层 + background = Image.new(mode="RGB", size=(image_wight, image_height), color=(255, 255, 255)) + + # 处理文字层 字数部分 + text_num_img = Image.new(mode="RGBA", size=(image_wight, image_height), color=(0, 0, 0, 0)) + font_num_size = 48 + font_num = ImageFont.truetype(font_path, font_num_size) + ImageDraw.Draw(text_num_img).text(xy=(0, 0), text=f'{len(text)}/100', font=font_num, fill=(255, 255, 255)) + + # 处理文字层 主体部分 + text_main_img = Image.new(mode="RGBA", size=(image_wight, image_height), color=(0, 0, 0, 0)) + font_main_size = 54 + font_main = ImageFont.truetype(font_path, font_main_size) + # 按长度切分文本 + spl_num = 0 + spl_list = [] + for num in range(len(text)): + text_w = font_main.getsize_multiline(text[spl_num:num])[0] + if text_w >= 415: + spl_list.append(text[spl_num:num]) + spl_num = num + else: + spl_list.append(text[spl_num:]) + test_main_fin = '' + for item in spl_list: + test_main_fin += item + '\n' + ImageDraw.Draw(text_main_img).multiline_text(xy=(0, 0), text=test_main_fin, font=font_main, spacing=8, + fill=(0, 0, 0)) + + # 处理文字部分旋转 + text_num_img = text_num_img.rotate(angle=-9, expand=True, resample=Image.BICUBIC, center=(0, 0)) + text_main_img = text_main_img.rotate(angle=-9.5, expand=True, resample=Image.BICUBIC, center=(0, 0)) + + # 向模板图片中置入文字图层 + background.paste(im=image_file, box=(0, 0)) + background.paste(im=text_num_img, box=(435, 140), mask=text_num_img) + background.paste(im=text_main_img, box=(130, 160), mask=text_main_img) + return background + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_static_jichou( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + 记仇表情包模板 + """ + def __handle() -> Image.Image: + # 处理文本主体 + text_ = f"今天是{date.today().strftime('%Y年%m月%d日')}{text}, 这个仇我先记下了" + font_main_size = 42 + font_main = ImageFont.truetype(font_path, font_main_size) + # 按长度切分文本 + spl_num = 0 + spl_list = [] + for num in range(len(text_)): + text_w = font_main.getsize_multiline(text_[spl_num:num])[0] + if text_w >= (image_wight * 7 // 8): + spl_list.append(text_[spl_num:num]) + spl_num = num + else: + spl_list.append(text_[spl_num:]) + text_main_fin = '\n'.join(spl_list) + + font = ImageFont.truetype(font_path, font_main_size) + text_w, text_h = font.getsize_multiline(text_main_fin) + + # 处理图片 + background_w = image_wight + background_h = image_height + text_h + 20 + background = Image.new(mode="RGB", size=(background_w, background_h), color=(255, 255, 255)) + + # 处理粘贴位置 顶头 + image_coordinate = (0, 0) + background.paste(image_file, image_coordinate) + + draw = ImageDraw.Draw(background) + # 计算居中文字位置 + text_coordinate = (((background_w - text_w) // 2), image_height + 5) + draw.multiline_text(text_coordinate, text_main_fin, font=font, fill=(0, 0, 0)) + return background + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result + + +async def stick_maker_static_phlogo( + text: str, image_file: Image.Image, font_path: str, image_wight: int, image_height: int) -> Image.Image: + """ + ph表情包模板 + """ + def __handle() -> Image.Image: + # 处理文本主体 + test_sentences = text.strip().split(maxsplit=1) + white_text = test_sentences[0] + yellow_text = test_sentences[1] + + font_size = 640 + font = ImageFont.truetype(font_path, font_size) + text_w, text_h = font.getsize(text) + + y_text_w, y_text_h = font.getsize(yellow_text) + bg_y_text = Image.new(mode="RGB", size=(round(y_text_w * 1.1), round(text_h * 1.3)), color=(254, 154, 0)) + draw_y_text = ImageDraw.Draw(bg_y_text) + draw_y_text.text((round(y_text_w * 1.1) // 2, round(text_h * 1.3) // 2), + yellow_text, anchor='mm', font=font, fill=(0, 0, 0)) + radii = 64 + # 画圆(用于分离4个角) + circle = Image.new('L', (radii * 2, radii * 2), 0) # 创建黑色方形 + draw_circle = ImageDraw.Draw(circle) + draw_circle.ellipse((0, 0, radii * 2, radii * 2), fill=255) # 黑色方形内切白色圆形 + # 原图转为带有alpha通道(表示透明程度) + bg_y_text = bg_y_text.convert("RGBA") + y_weight, y_height = bg_y_text.size + # 画4个角(将整圆分离为4个部分) + alpha = Image.new('L', bg_y_text.size, 255) # 与img同大小的白色矩形,L 表示黑白图 + alpha.paste(circle.crop((0, 0, radii, radii)), (0, 0)) # 左上角 + alpha.paste(circle.crop((radii, 0, radii * 2, radii)), (y_weight - radii, 0)) # 右上角 + alpha.paste(circle.crop((radii, radii, radii * 2, radii * 2)), (y_weight - radii, y_height - radii)) # 右下角 + alpha.paste(circle.crop((0, radii, radii, radii * 2)), (0, y_height - radii)) # 左下角 + bg_y_text.putalpha(alpha) # 白色区域透明可见,黑色区域不可见 + + w_text_w, w_text_h = font.getsize(white_text) + bg_w_text = Image.new(mode="RGB", size=(round(w_text_w * 1.05), round(text_h * 1.3)), color=(0, 0, 0)) + w_weight, w_height = bg_w_text.size + draw_w_text = ImageDraw.Draw(bg_w_text) + draw_w_text.text((round(w_text_w * 1.025) // 2, round(text_h * 1.3) // 2), + white_text, anchor='mm', font=font, fill=(255, 255, 255)) + + text_bg = Image.new(mode="RGB", size=(w_weight + y_weight, y_height), color=(0, 0, 0)) + text_bg.paste(bg_w_text, (0, 0)) + text_bg.paste(bg_y_text, (round(w_text_w * 1.05), 0), mask=alpha) + t_weight, t_height = text_bg.size + + background = Image.new(mode="RGB", size=(round(t_weight * 1.2), round(t_height * 1.75)), color=(0, 0, 0)) + b_weight, b_height = background.size + background.paste(text_bg, ((b_weight - t_weight) // 2, (b_height - t_height) // 2)) + return background + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, __handle) + return result __all__ = [ diff --git a/omega_miya/plugins/su_self_sent/__init__.py b/omega_miya/plugins/su_self_sent/__init__.py new file mode 100644 index 00000000..cd666779 --- /dev/null +++ b/omega_miya/plugins/su_self_sent/__init__.py @@ -0,0 +1,100 @@ +""" +@Author : Ailitonia +@Date : 2021/05/31 21:14 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : go-cqhttp 适配专用, 用于人工同时登陆 bot 账号时将自己发送的消息转成 message 类型便于执行命令, + bot 账号发送命令前添加 !SU 即可将消息事件由 message_sent 转换为 group_message, 仅限群组中生效, + 为避免命令恶意执行, bot 不能为 superuser +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import re +from datetime import datetime +from nonebot import logger +from nonebot.plugin import on, CommandGroup +from nonebot.typing import T_State +from nonebot.message import handle_event +from nonebot.rule import to_me +from nonebot.permission import SUPERUSER +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.message import Message +from nonebot.adapters.cqhttp.event import Event, MessageEvent, GroupMessageEvent + + +SU_TAG: bool = False + +# 注册事件响应器 +Su = CommandGroup('Su', rule=to_me(), permission=SUPERUSER, priority=10, block=True) + +su_on = Su.command('on') +su_off = Su.command('off') + + +@su_on.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + global SU_TAG + SU_TAG = True + logger.info(f'Su: 特权命令已启用, 下一条!SU命令将以管理员身份执行') + await su_on.finish(f'特权命令已启用, 下一条!SU命令将以管理员身份执行') + + +@su_off.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + global SU_TAG + SU_TAG = False + logger.info(f'Su: 特权命令已禁用') + await su_off.finish(f'特权命令已禁用') + + +self_sent_msg_convertor = on( + type='message_sent', + priority=10, + block=False +) + + +@self_sent_msg_convertor.handle() +async def _handle(bot: Bot, event: Event, state: T_State): + self_id = event.dict().get('self_id', -1) + user_id = event.dict().get('user_id', -1) + if self_id == user_id and str(self_id) == bot.self_id and str(self_id) not in bot.config.superusers: + raw_message = event.dict().get('raw_message', '') + if str(raw_message).startswith('!SU'): + global SU_TAG + try: + if SU_TAG and list(bot.config.superusers): + user_id = int(list(bot.config.superusers)[0]) + raw_message = re.sub(r'^!SU', '', str(raw_message)).strip() + message = Message(raw_message) + time = event.dict().get('time', int(datetime.now().timestamp())) + sub_type = event.dict().get('sub_type', 'normal') + group_id = event.dict().get('group_id', -1) + message_type = event.dict().get('message_type', 'group') + message_id = event.dict().get('message_id', -1) + font = event.dict().get('font', 0) + sender = event.dict().get('sender', {'user_id': user_id}) + + new_event = GroupMessageEvent(**{ + 'time': time, + 'self_id': self_id, + 'post_type': 'message', + 'sub_type': sub_type, + 'user_id': user_id, + 'group_id': group_id, + 'message_type': message_type, + 'message_id': message_id, + 'message': message, + 'raw_message': raw_message, + 'font': font, + 'sender': sender + }) + + await handle_event(bot=bot, event=new_event) + except Exception as e: + logger.error(f'Self sent msg convertor convert an self_sent event failed, ' + f'error: {repr(e)}, event: {event}.') + finally: + SU_TAG = False + logger.info(f'Su: !SU命令已执行, SU_TAG已复位.') diff --git a/omega_miya/plugins/tencent_cloud/__init__.py b/omega_miya/plugins/tencent_cloud/__init__.py new file mode 100644 index 00000000..3ddff66a --- /dev/null +++ b/omega_miya/plugins/tencent_cloud/__init__.py @@ -0,0 +1,119 @@ +import re +from nonebot import MatcherGroup, logger, export +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import GroupMessageEvent +from nonebot.adapters.cqhttp.permission import GROUP +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, OmegaRules +from omega_miya.utils.tencent_cloud_api import TencentNLP, TencentTMT + + +# Custom plugin usage text +__plugin_name__ = 'TencentCloudCore' +__plugin_usage__ = r'''【TencentCloud API Support】 +腾讯云API插件 +测试中 + +**Permission** +Command & Lv.50 +or AuthNode + +**AuthNode** +basic + +**Usage** +/翻译''' + +# 声明本插件可配置的权限节点 +__plugin_auth_node__ = [ + 'tmt', + 'nlp' +] + +# Init plugin export +init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) + + +tencent_cloud = MatcherGroup( + type='message', + permission=GROUP, + priority=100, + block=False) + + +translate = tencent_cloud.on_command( + '翻译', + aliases={'translate'}, + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='translate', + command=True, + level=30, + auth_node='tmt'), + priority=30, + block=True) + + +# 修改默认参数处理 +@translate.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip() + if not args: + await translate.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args + if state[state["_current_key"]] == '取消': + await translate.finish('操作已取消') + + +@translate.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip() + if not args: + pass + else: + state['content'] = args + + +@translate.got('content', prompt='请发送需要翻译的内容:') +async def handle_roll(bot: Bot, event: GroupMessageEvent, state: T_State): + content = state['content'] + translate_result = await TencentTMT().translate(source_text=content) + if translate_result.error: + await translate.finish('翻译失败了QAQ, 发生了意外的错误') + else: + await translate.finish(f"翻译结果:\n\n{translate_result.result.get('targettext')}") + + +nlp = tencent_cloud.on_message( + rule=OmegaRules.has_group_command_permission() & OmegaRules.has_level_or_node(30, 'tencent_cloud.nlp')) + + +@nlp.handle() +async def handle_nlp(bot: Bot, event: GroupMessageEvent, state: T_State): + arg = str(event.get_plaintext()).strip().lower() + + # 排除列表 + ignore_pattern = [ + re.compile(r'喵一个'), + re.compile(r'^今天'), + re.compile(r'[这那谁你我他她它]个?是[(什么)谁啥]') + ] + for pattern in ignore_pattern: + if re.search(pattern, arg): + await nlp.finish() + + # describe_entity实体查询 + if re.match(r'^(你?知道)?(.{1,32}?)的(.{1,32}?)是(什么|谁|啥)吗?[??]?$', arg): + item, attr = re.findall(r'^(你?知道)?(.{1,32}?)的(.{1,32}?)是(什么|谁|啥)吗?[??]?$', arg)[0][1:3] + res = await TencentNLP().describe_entity(entity_name=item, attr=attr) + if not res.error and res.result: + await nlp.finish(f'{item}的{attr}是{res.result}') + else: + logger.warning(f'nlp handling describe entity failed: {res.info}') + elif re.match(r'^(你?知道)?(.{1,32}?)是(什么|谁|啥)吗?[??]?$', arg): + item = re.findall(r'^(你?知道)?(.{1,32}?)是(什么|谁|啥)吗?[??]?$', arg)[0][1] + res = await TencentNLP().describe_entity(entity_name=item) + if not res.error and res.result: + await nlp.finish(str(res.result)) + else: + logger.warning(f'nlp handling describe entity failed: {res.info}') diff --git a/omega_miya/plugins/tencent_nlp/__init__.py b/omega_miya/plugins/tencent_nlp/__init__.py deleted file mode 100644 index cd9ac383..00000000 --- a/omega_miya/plugins/tencent_nlp/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -import re -from nonebot import MatcherGroup, logger -from nonebot.typing import T_State -from nonebot.adapters.cqhttp.bot import Bot -from nonebot.adapters.cqhttp.event import GroupMessageEvent -from nonebot.adapters.cqhttp.permission import GROUP -from omega_miya.utils.Omega_plugin_utils import has_command_permission -from omega_miya.utils.tencent_cloud_api import TencentNLP - -""" -腾讯云nlp插件 -测试中 -""" - -Nlp = MatcherGroup( - type='message', - rule=has_command_permission(), - permission=GROUP, - priority=100, - block=False) - - -nlp = Nlp.on_message() - - -@nlp.handle() -async def handle_nlp(bot: Bot, event: GroupMessageEvent, state: T_State): - arg = str(event.get_plaintext()).strip().lower() - - # 排除列表 - ignore_pattern = [ - re.compile(r'喵一个'), - re.compile(r'[这那]个?是[(什么)谁啥]') - ] - for pattern in ignore_pattern: - if re.search(pattern, arg): - await nlp.finish() - - # describe_entity实体查询 - if re.match(r'^(你?知道)?(.{1,32})是(什么|谁|啥)吗?[??]?$', arg): - item = re.findall(r'^(你?知道)?(.{1,32}?)是(什么|谁|啥)吗?[??]?$', arg)[0][1] - res = await TencentNLP().describe_entity(entity_name=item) - if not res.error and res.result: - await nlp.finish(str(res.result)) - else: - logger.warning(f'nlp handling describe entity failed: {res.info}') diff --git a/omega_miya/plugins/zhoushen_hime/__init__.py b/omega_miya/plugins/zhoushen_hime/__init__.py index 6ed32d42..9174b0c3 100644 --- a/omega_miya/plugins/zhoushen_hime/__init__.py +++ b/omega_miya/plugins/zhoushen_hime/__init__.py @@ -2,12 +2,15 @@ 要求go-cqhttp v0.9.40以上 """ import os -from nonebot import on_notice, export, logger +from nonebot import on_command, on_notice, export, logger from nonebot.typing import T_State +from nonebot.permission import SUPERUSER from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_OWNER from nonebot.adapters.cqhttp.message import MessageSegment, Message -from nonebot.adapters.cqhttp.event import GroupUploadNoticeEvent -from omega_miya.utils.Omega_plugin_utils import init_export, has_auth_node +from nonebot.adapters.cqhttp.event import GroupMessageEvent, GroupUploadNoticeEvent +from omega_miya.utils.Omega_plugin_utils import init_export, init_permission_state, OmegaRules +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBAuth, Result from .utils import ZhouChecker, download_file @@ -18,11 +21,16 @@ 检测群内上传文件并自动锤轴 仅限群聊使用 +**Permission** +Group only with +AuthNode + **AuthNode** basic **Usage** -配置AuthNode启用''' +**GroupAdmin and SuperUser Only** +/ZhouShenHime ''' # 声明本插件可配置的权限节点 __plugin_auth_node__ = [ @@ -32,11 +40,93 @@ # Init plugin export init_export(export(), __plugin_name__, __plugin_usage__, __plugin_auth_node__) - -ZhouShenHime = on_notice(rule=has_auth_node(__plugin_raw_name__, 'basic'), priority=100, block=False) - - -@ZhouShenHime.handle() +# 注册事件响应器 +zhoushen_hime_admin = on_command( + 'ZhouShenHime', + aliases={'zhoushenhime', '审轴姬', '审轴机'}, + # 使用run_preprocessor拦截权限管理, 在default_state初始化所需权限 + state=init_permission_state( + name='zhoushen_hime', + command=True, + level=10), + permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, + priority=10, + block=True) + + +# 修改默认参数处理 +@zhoushen_hime_admin.args_parser +async def parse(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + await zhoushen_hime_admin.reject('你似乎没有发送有效的参数呢QAQ, 请重新发送:') + state[state["_current_key"]] = args[0] + if state[state["_current_key"]] == '取消': + await zhoushen_hime_admin.finish('操作已取消') + + +@zhoushen_hime_admin.handle() +async def handle_first_receive(bot: Bot, event: GroupMessageEvent, state: T_State): + args = str(event.get_plaintext()).strip().lower().split() + if not args: + pass + elif args and len(args) == 1: + state['sub_command'] = args[0] + else: + await zhoushen_hime_admin.finish('参数错误QAQ') + + +@zhoushen_hime_admin.got('sub_command', prompt='执行操作?\n【ON/OFF】') +async def handle_sub_command_args(bot: Bot, event: GroupMessageEvent, state: T_State): + sub_command = state['sub_command'] + if sub_command not in ['on', 'off']: + await zhoushen_hime_admin.reject('没有这个选项哦, 请在【ON/OFF】中选择并重新发送, 取消命令请发送【取消】:') + + if sub_command == 'on': + _res = await zhoushen_hime_on(bot=bot, event=event, state=state) + elif sub_command == 'off': + _res = await zhoushen_hime_off(bot=bot, event=event, state=state) + else: + _res = Result.IntResult(error=True, info='Unknown error, except sub_command', result=-1) + + if _res.success(): + logger.info(f"设置自动审轴姬状态为 {sub_command} 成功, group_id: {event.group_id}, {_res.info}") + await zhoushen_hime_admin.finish(f'已设置自动审轴姬状态为 {sub_command}!') + else: + logger.error(f"设置自动审轴姬状态为 {sub_command} 失败, group_id: {event.group_id}, {_res.info}") + await zhoushen_hime_admin.finish(f'设置自动审轴姬状态失败了QAQ, 请稍后再试~') + + +async def zhoushen_hime_on(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + return Result.IntResult(error=False, info='Group not exist', result=-1) + + auth_node = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic') + result = await auth_node.set(allow_tag=1, deny_tag=0, auth_info='启用自动审轴姬') + return result + + +async def zhoushen_hime_off(bot: Bot, event: GroupMessageEvent, state: T_State) -> Result.IntResult: + group_id = event.group_id + self_bot = DBBot(self_qq=int(bot.self_id)) + group = DBBotGroup(group_id=group_id, self_bot=self_bot) + group_exist = await group.exist() + if not group_exist: + return Result.IntResult(error=False, info='Group not exist', result=-1) + + auth_node = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=f'{__plugin_raw_name__}.basic') + result = await auth_node.set(allow_tag=0, deny_tag=1, auth_info='禁用自动审轴姬') + return result + + +zhoushen_hime = on_notice(rule=OmegaRules.has_auth_node(__plugin_raw_name__, 'basic'), priority=100, block=False) + + +@zhoushen_hime.handle() async def hime_handle(bot: Bot, event: GroupUploadNoticeEvent, state: T_State): file_name = event.file.name file_url = event.file.dict().get('url') @@ -44,23 +134,23 @@ async def hime_handle(bot: Bot, event: GroupUploadNoticeEvent, state: T_State): # 不响应自己上传的文件 if int(event.user_id) == int(bot.self_id): - await ZhouShenHime.finish() + await zhoushen_hime.finish() if file_name.split('.')[-1] not in ['ass', 'ASS']: - await ZhouShenHime.finish() + await zhoushen_hime.finish() # 只处理文件名中含"未校""待校""需校"的文件 if not any(key in file_name for key in ['未校', '待校', '需校']): - await ZhouShenHime.finish() + await zhoushen_hime.finish() dl_res = await download_file(url=file_url, file_name=file_name) if not dl_res.success(): logger.error(f'下载文件失败: {dl_res.info}') - await ZhouShenHime.finish() + await zhoushen_hime.finish() at_msg = MessageSegment.at(user_id=user_id) msg = f'{at_msg}你刚刚上传了一份轴呢, 让我来帮你看看吧!' - await ZhouShenHime.send(Message(msg)) + await zhoushen_hime.send(Message(msg)) file_path = os.path.abspath(dl_res.result) checker = ZhouChecker(file_path=file_path, flash_mode=True) @@ -69,15 +159,15 @@ async def hime_handle(bot: Bot, event: GroupUploadNoticeEvent, state: T_State): init_res = checker.init_file(auto_style=True) if not init_res.success(): logger.error(f'初始化时轴文件失败: {init_res.info}') - await ZhouShenHime.finish('出错了QAQ') + await zhoushen_hime.finish('审轴姬出错了QAQ') handle_res = checker.handle() if not handle_res.success(): logger.error(f'处理时轴文件失败: {handle_res.info}') - await ZhouShenHime.finish('出错了QAQ') + await zhoushen_hime.finish('审轴姬出错了QAQ') except Exception as e: logger.error(f'执行ZhouChecker时发生了意外的错误: {repr(e)}') - await ZhouShenHime.finish('出错了QAQ') + await zhoushen_hime.finish('审轴姬出错了QAQ') return output_txt_path = os.path.abspath(handle_res.result.get('output_txt_path')) @@ -93,7 +183,7 @@ async def hime_handle(bot: Bot, event: GroupUploadNoticeEvent, state: T_State): # 没有检查到错误的话就直接结束 if character_count + flash_count + overlap_count == 0: msg = f'看完了! 没有发现符号错误、疑问文本、叠轴和闪轴, 真棒~' - await ZhouShenHime.finish(msg) + await zhoushen_hime.finish(msg) try: group_file_info = await bot.call_api(api='get_group_root_files', group_id=event.group_id) @@ -118,9 +208,9 @@ async def hime_handle(bot: Bot, event: GroupUploadNoticeEvent, state: T_State): file=output_ass_path, name=output_ass_filename) except Exception as e: logger.error(f'上传结果时时发生了意外的错误: {repr(e)}') - await ZhouShenHime.finish('出错了QAQ') + await zhoushen_hime.finish('审轴姬出错了QAQ') return msg = f'看完了! 以下是结果:\n\n符号及疑问文本共{character_count}处\n' \ f'叠轴共{overlap_count}处\n闪轴共{flash_count}处\n\n锤轴结果已上传, 请参考修改哟~' - await ZhouShenHime.finish(msg) + await zhoushen_hime.finish(msg) diff --git a/omega_miya/plugins/zhoushen_hime/utils.py b/omega_miya/plugins/zhoushen_hime/utils.py index d2420a0e..1f831062 100644 --- a/omega_miya/plugins/zhoushen_hime/utils.py +++ b/omega_miya/plugins/zhoushen_hime/utils.py @@ -17,10 +17,10 @@ def __init__(self, *args): # 构造ass字幕类 class AssScriptLine(object): # 标记属性 - __Style: str = 'Style' - __Dialogue: str = 'Dialogue' - __Comment: str = 'Comment' - __Header: str = 'Header' + __STYLE: str = 'Style' + __DIALOGUE: str = 'Dialogue' + __COMMENT: str = 'Comment' + __HEADER: str = 'Header' # 为方便时间计算, 字幕时间起点以0点为基准 @classmethod @@ -140,16 +140,16 @@ def init(self) -> None: self.__raw_text = self.__raw_text.strip() # 判断类型 - if self.__raw_text.startswith(AssScriptLine.__Style): - self.__type = AssScriptLine.__Style + if self.__raw_text.startswith(self.__STYLE): + self.__type = self.__STYLE split_line = self.__raw_text.split(',', maxsplit=22) self.__style = split_line[0].split(':')[1].strip() self.__is_init = True - elif self.__raw_text.startswith(AssScriptLine.__Dialogue): - self.__type = AssScriptLine.__Dialogue + elif self.__raw_text.startswith(self.__DIALOGUE): + self.__type = self.__DIALOGUE split_line = self.__raw_text.split(',', maxsplit=9) self.__start_time = self.__time_handle(time=split_line[1]) @@ -167,8 +167,8 @@ def init(self) -> None: self.__is_init = True - elif self.__raw_text.startswith(AssScriptLine.__Comment): - self.__type = AssScriptLine.__Comment + elif self.__raw_text.startswith(self.__COMMENT): + self.__type = self.__COMMENT split_line = self.__raw_text.split(',', maxsplit=9) self.__start_time = self.__time_handle(time=split_line[1]) @@ -187,7 +187,7 @@ def init(self) -> None: self.__is_init = True else: - self.__type = AssScriptLine.__Header + self.__type = self.__HEADER self.__is_init = True @@ -198,7 +198,7 @@ def generate(self) -> str: if not self.__is_init: return self.raw_text - if self.__type in [AssScriptLine.__Header, AssScriptLine.__Style]: + if self.__type in [self.__HEADER, self.__STYLE]: return self.raw_text start_time = self.start_time.strftime('%H:%M:%S.%f')[:-4] @@ -229,7 +229,7 @@ def check_flash(self, threshold_time: int) -> Tuple[int, datetime.timedelta]: if not self.__is_init: return -1, datetime.timedelta(0) - if self.__type != AssScriptLine.__Dialogue: + if self.__type != self.__DIALOGUE: return -1, datetime.timedelta(0) threshold_duration = datetime.timedelta(microseconds=threshold_time * 1000) @@ -266,12 +266,12 @@ def __repr__(self): # 构造ass字幕event行工具类 class AssScriptLineTool(object): # 标记属性 - __Style: str = 'Style' - __Dialogue: str = 'Dialogue' - __Comment: str = 'Comment' - __Header: str = 'Header' + __STYLE: str = 'Style' + __DIALOGUE: str = 'Dialogue' + __COMMENT: str = 'Comment' + __HEADER: str = 'Header' - @ classmethod + @classmethod def check_continuous(cls, start_line: AssScriptLine, end_line: AssScriptLine, style_mode: bool) \ -> Tuple[int, datetime.timedelta]: """ @@ -290,7 +290,7 @@ def check_continuous(cls, start_line: AssScriptLine, end_line: AssScriptLine, st if not all([start_line.is_init, end_line.is_init]): return -1, datetime.timedelta(0) - if start_line.type != AssScriptLineTool.__Dialogue or end_line.type != AssScriptLineTool.__Dialogue: + if start_line.type != cls.__DIALOGUE or end_line.type != cls.__DIALOGUE: return -1, datetime.timedelta(0) lines_duration = \ @@ -306,7 +306,7 @@ def check_continuous(cls, start_line: AssScriptLine, end_line: AssScriptLine, st else: return 0, lines_duration - @ classmethod + @classmethod def check_overlap(cls, start_line: AssScriptLine, end_line: AssScriptLine, style_mode: bool) \ -> Tuple[int, datetime.timedelta]: """ @@ -325,7 +325,7 @@ def check_overlap(cls, start_line: AssScriptLine, end_line: AssScriptLine, style if not all([start_line.is_init, end_line.is_init]): return -1, datetime.timedelta(0) - if start_line.type != AssScriptLineTool.__Dialogue or end_line.type != AssScriptLineTool.__Dialogue: + if start_line.type != cls.__DIALOGUE or end_line.type != cls.__DIALOGUE: return -1, datetime.timedelta(0) if style_mode: @@ -341,7 +341,7 @@ def check_overlap(cls, start_line: AssScriptLine, end_line: AssScriptLine, style else: return 0, lines_duration - @ classmethod + @classmethod def check_flash(cls, start_line: AssScriptLine, end_line: AssScriptLine, threshold_time: int, style_mode: bool) -> Tuple[int, datetime.timedelta]: """ @@ -361,7 +361,7 @@ def check_flash(cls, start_line: AssScriptLine, end_line: AssScriptLine, if not all([start_line.is_init, end_line.is_init]): return -1, datetime.timedelta(0) - if start_line.type != AssScriptLineTool.__Dialogue or end_line.type != AssScriptLineTool.__Dialogue: + if start_line.type != cls.__DIALOGUE or end_line.type != cls.__DIALOGUE: return -1, datetime.timedelta(0) lines_duration = \ @@ -385,9 +385,9 @@ def check_flash(cls, start_line: AssScriptLine, end_line: AssScriptLine, # 构造ass字幕文件处理工具类 -class ZhouChecker(object): +class ZhouChecker(AssScriptLineTool): # 需要校对的关键词 - __proofreading_words = ['ong', '???', '???'] + __proofreading_words = ['???', '???'] # 要替换的标点, key为替换前, value为替换后 __punctuation_replace = { @@ -404,12 +404,14 @@ class ZhouChecker(object): '!': '!', '?': '?', ' ': ' ', + '[': '「', + ']': '」', '【': '「', '】': '」' } # 不知道咋换的标点 - __punctuation_ignore = ["'", '"', ','] + __punctuation_ignore = ["'", '"', ',', '/'] def __init__(self, file_path: str, single_threshold_time: int = 500, multi_threshold_time: int = 300, flash_mode: bool = False, style_mode: bool = False, fx_mode: bool = True): @@ -538,15 +540,15 @@ def handle(self) -> Result.DictResult: threshold_time=self.__single_threshold_time) # 检查连轴 - continuous, continuous_lines_duration = AssScriptLineTool.check_continuous( + continuous, continuous_lines_duration = self.check_continuous( start_line=start_line, end_line=end_line, style_mode=style_mode) # 检查叠轴 - overlap, overlap_duration = AssScriptLineTool.check_overlap( + overlap, overlap_duration = self.check_overlap( start_line=start_line, end_line=end_line, style_mode=style_mode) # 检查轴间闪轴 - multi_flash, multi_flash_lines_duration = AssScriptLineTool.check_flash( + multi_flash, multi_flash_lines_duration = self.check_flash( start_line=start_line, end_line=end_line, threshold_time=self.__multi_threshold_time, style_mode=style_mode) diff --git a/omega_miya/utils/Omega_Base/__init__.py b/omega_miya/utils/Omega_Base/__init__.py index 855c9333..d4c23743 100644 --- a/omega_miya/utils/Omega_Base/__init__.py +++ b/omega_miya/utils/Omega_Base/__init__.py @@ -3,24 +3,23 @@ 其他插件不得单独写入数据库操作逻辑 """ -from .database import DBTable from .class_result import Result -from .model import \ - DBUser, DBFriend, DBGroup, DBSkill, DBSubscription, DBDynamic, \ - DBPixivillust, DBPixivtag, DBPixivision, \ - DBEmail, DBEmailBox, DBHistory, DBAuth, DBCoolDownEvent, DBStatus +from .model import ( + DBUser, DBFriend, DBBot, DBBotGroup, DBGroup, DBSkill, DBSubscription, DBDynamic, + DBPixivUserArtwork, DBPixivillust, DBPixivision, DBEmail, DBEmailBox, DBHistory, DBAuth, DBCoolDownEvent, DBStatus) __all__ = [ - 'DBTable', 'DBUser', 'DBFriend', + 'DBBot', + 'DBBotGroup', 'DBGroup', 'DBSkill', 'DBSubscription', 'DBDynamic', + 'DBPixivUserArtwork', 'DBPixivillust', - 'DBPixivtag', 'DBPixivision', 'DBEmail', 'DBEmailBox', diff --git a/omega_miya/utils/Omega_Base/class_result.py b/omega_miya/utils/Omega_Base/class_result.py index 717f971f..1f81fe94 100644 --- a/omega_miya/utils/Omega_Base/class_result.py +++ b/omega_miya/utils/Omega_Base/class_result.py @@ -64,6 +64,13 @@ class TextListResult(BaseResult): def __repr__(self): return f'' + @dataclass + class TupleListResult(BaseResult): + result: List[tuple] + + def __repr__(self): + return f'' + @dataclass class DictListResult(BaseResult): result: List[dict] @@ -134,6 +141,13 @@ class BoolResult(BaseResult): def __repr__(self): return f'' + @dataclass + class BytesResult(BaseResult): + result: bytes + + def __repr__(self): + return f'' + @dataclass class AnyResult(BaseResult): result: Any diff --git a/omega_miya/utils/Omega_Base/database.py b/omega_miya/utils/Omega_Base/database.py index 8247e8eb..96603c62 100644 --- a/omega_miya/utils/Omega_Base/database.py +++ b/omega_miya/utils/Omega_Base/database.py @@ -6,7 +6,9 @@ from .tables import Base from .class_result import Result -global_config = nonebot.get_driver().config +driver = nonebot.get_driver() + +global_config = driver.config __DATABASE = 'mysql' __DB_DRIVER = 'aiomysql' __DB_USER = global_config.db_user @@ -32,7 +34,10 @@ sys.exit('创建数据库连接失败') +# 初始化化数据库 +@driver.on_startup async def database_init(): + nonebot.logger.opt(colors=True).info(f'正在初始化数据库......') try: # 初始化数据库结构 # conn is an instance of AsyncConnection @@ -43,17 +48,13 @@ async def database_init(): # where synchronous IO calls will be transparently translated for # await. await conn.run_sync(Base.metadata.create_all) - nonebot.logger.opt(colors=True).debug(f'初始化数据库...完成') + nonebot.logger.opt(colors=True).info(f'数据库初始化已完成.') except Exception as e: import sys nonebot.logger.opt(colors=True).critical(f'数据库初始化失败, error: {repr(e)}') sys.exit('数据库初始化失败') -# 初始化化数据库 -nonebot.get_driver().on_startup(database_init) - - class NBdb(object): def __init__(self): # expire_on_commit=False will prevent attributes from being expired @@ -68,6 +69,10 @@ def get_async_session(self): class DBTable(object): + """ + 已弃用, 保留相关代码仅供参考 + 任何情况下请直接调用 model 中相关类, 不要使用本类构造实例 + """ def __init__(self, table_name): self.__tables = Base self.table_name = table_name @@ -119,6 +124,5 @@ async def list_col_with_condition(self, col_name, condition_col_name, condition) __all__ = [ - 'NBdb', - 'DBTable' + 'NBdb' ] diff --git a/omega_miya/utils/Omega_Base/model/__init__.py b/omega_miya/utils/Omega_Base/model/__init__.py index 11be0c66..ee17e7bb 100644 --- a/omega_miya/utils/Omega_Base/model/__init__.py +++ b/omega_miya/utils/Omega_Base/model/__init__.py @@ -1,13 +1,15 @@ from .auth import DBAuth from .bilidynamic import DBDynamic +from .bot_group import DBBotGroup +from .bot_self import DBBot from .cooldown import DBCoolDownEvent from .friend import DBFriend from .group import DBGroup from .history import DBHistory from .mail import DBEmail, DBEmailBox +from .pixiv_user_artwork import DBPixivUserArtwork from .pixivillust import DBPixivillust from .pixivision import DBPixivision -from .pixivtag import DBPixivtag from .skill import DBSkill from .subscription import DBSubscription from .user import DBUser @@ -16,15 +18,17 @@ __all__ = [ 'DBAuth', 'DBDynamic', + 'DBBotGroup', + 'DBBot', 'DBCoolDownEvent', 'DBFriend', 'DBGroup', 'DBHistory', 'DBEmail', 'DBEmailBox', + 'DBPixivUserArtwork', 'DBPixivillust', 'DBPixivision', - 'DBPixivtag', 'DBSkill', 'DBSubscription', 'DBUser', diff --git a/omega_miya/utils/Omega_Base/model/auth.py b/omega_miya/utils/Omega_Base/model/auth.py index 4d4b0b66..485dde5e 100644 --- a/omega_miya/utils/Omega_Base/model/auth.py +++ b/omega_miya/utils/Omega_Base/model/auth.py @@ -1,35 +1,45 @@ from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result -from omega_miya.utils.Omega_Base.tables import AuthUser, AuthGroup, User, Group -from .user import DBUser -from .group import DBGroup +from omega_miya.utils.Omega_Base.tables import AuthUser, AuthGroup, User, Friends, Group, BotGroup +from .friend import DBFriend +from .bot_group import DBBotGroup +from .bot_self import DBBot from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound class DBAuth(object): - def __init__(self, auth_id: int, auth_type: str, auth_node: str): + def __init__(self, self_bot: DBBot, auth_id: int, auth_type: str, auth_node: str): """ + :param self_bot: 对应DBBot对象 :param auth_id: 请求授权id, 用户qq号或群组群号 :param auth_type: user: 用户授权 group: 群组授权 :param auth_node: 授权节点 """ + self.self_bot = self_bot self.auth_id = auth_id self.auth_type = auth_type self.auth_node = auth_node async def id(self) -> Result.IntResult: + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: if self.auth_type == 'user': session_result = await session.execute( - select(AuthUser.id).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser.id). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == self.auth_id). where(AuthUser.auth_node == self.auth_node) ) @@ -37,8 +47,11 @@ async def id(self) -> Result.IntResult: result = Result.IntResult(error=False, info='Success', result=auth_table_id) elif self.auth_type == 'group': session_result = await session.execute( - select(AuthGroup.id).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup.id). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == self.auth_id). where(AuthGroup.auth_node == self.auth_node) ) @@ -59,6 +72,10 @@ async def exist(self) -> bool: return result.success() async def set(self, allow_tag: int, deny_tag: int, auth_info: str = None) -> Result.IntResult: + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -66,8 +83,11 @@ async def set(self, allow_tag: int, deny_tag: int, auth_info: str = None) -> Res try: if self.auth_type == 'user': session_result = await session.execute( - select(AuthUser).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == self.auth_id). where(AuthUser.auth_node == self.auth_node) ) @@ -79,8 +99,11 @@ async def set(self, allow_tag: int, deny_tag: int, auth_info: str = None) -> Res result = Result.IntResult(error=False, info='Success upgraded', result=0) elif self.auth_type == 'group': session_result = await session.execute( - select(AuthGroup).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == self.auth_id). where(AuthGroup.auth_node == self.auth_node) ) @@ -94,23 +117,23 @@ async def set(self, allow_tag: int, deny_tag: int, auth_info: str = None) -> Res result = Result.IntResult(error=True, info='Auth type error', result=-1) except NoResultFound: if self.auth_type == 'user': - user = DBUser(user_id=self.auth_id) - user_id_result = await user.id() - if user_id_result.error: - result = Result.IntResult(error=True, info='User not exist', result=-1) + friend = DBFriend(user_id=self.auth_id, self_bot=self.self_bot) + friend_id_result = await friend.friend_id() + if friend_id_result.error: + result = Result.IntResult(error=True, info='Friend not exist', result=-1) else: - auth = AuthUser(user_id=user_id_result.result, auth_node=self.auth_node, + auth = AuthUser(user_id=friend_id_result.result, auth_node=self.auth_node, allow_tag=allow_tag, deny_tag=deny_tag, auth_info=auth_info, created_at=datetime.now()) session.add(auth) result = Result.IntResult(error=False, info='Success set', result=0) elif self.auth_type == 'group': - group = DBGroup(group_id=self.auth_id) - group_id_result = await group.id() - if group_id_result.error: + bot_group = DBBotGroup(group_id=self.auth_id, self_bot=self.self_bot) + bot_group_id_result = await bot_group.bot_group_id() + if bot_group_id_result.error: result = Result.IntResult(error=True, info='Group not exist', result=-1) else: - auth = AuthGroup(group_id=group_id_result.result, auth_node=self.auth_node, + auth = AuthGroup(group_id=bot_group_id_result.result, auth_node=self.auth_node, allow_tag=allow_tag, deny_tag=deny_tag, auth_info=auth_info, created_at=datetime.now()) session.add(auth) @@ -127,14 +150,21 @@ async def set(self, allow_tag: int, deny_tag: int, auth_info: str = None) -> Res return result async def allow_tag(self) -> Result.IntResult: + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: if self.auth_type == 'user': session_result = await session.execute( - select(AuthUser.allow_tag).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser.allow_tag). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == self.auth_id). where(AuthUser.auth_node == self.auth_node) ) @@ -142,8 +172,11 @@ async def allow_tag(self) -> Result.IntResult: result = Result.IntResult(error=False, info='Success', result=allow_tag) elif self.auth_type == 'group': session_result = await session.execute( - select(AuthGroup.allow_tag).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup.allow_tag). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == self.auth_id). where(AuthGroup.auth_node == self.auth_node) ) @@ -160,14 +193,21 @@ async def allow_tag(self) -> Result.IntResult: return result async def deny_tag(self) -> Result.IntResult: + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: if self.auth_type == 'user': session_result = await session.execute( - select(AuthUser.deny_tag).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser.deny_tag). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == self.auth_id). where(AuthUser.auth_node == self.auth_node) ) @@ -175,8 +215,11 @@ async def deny_tag(self) -> Result.IntResult: result = Result.IntResult(error=False, info='Success', result=deny_tag) elif self.auth_type == 'group': session_result = await session.execute( - select(AuthGroup.deny_tag).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup.deny_tag). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == self.auth_id). where(AuthGroup.auth_node == self.auth_node) ) @@ -193,14 +236,21 @@ async def deny_tag(self) -> Result.IntResult: return result async def tags_info(self) -> Result.IntTupleResult: + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntTupleResult(error=True, info='Bot not exist', result=(-1, -1)) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: if self.auth_type == 'user': session_result = await session.execute( - select(AuthUser.allow_tag, AuthUser.deny_tag).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser.allow_tag, AuthUser.deny_tag). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == self.auth_id). where(AuthUser.auth_node == self.auth_node) ) @@ -208,8 +258,11 @@ async def tags_info(self) -> Result.IntTupleResult: result = Result.IntTupleResult(error=False, info='Success', result=(res[0], res[1])) elif self.auth_type == 'group': session_result = await session.execute( - select(AuthGroup.allow_tag, AuthGroup.deny_tag).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup.allow_tag, AuthGroup.deny_tag). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == self.auth_id). where(AuthGroup.auth_node == self.auth_node) ) @@ -226,14 +279,21 @@ async def tags_info(self) -> Result.IntTupleResult: return result async def delete(self) -> Result.IntResult: + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): if self.auth_type == 'user': session_result = await session.execute( - select(AuthUser).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == self.auth_id). where(AuthUser.auth_node == self.auth_node) ) @@ -242,8 +302,11 @@ async def delete(self) -> Result.IntResult: result = Result.IntResult(error=False, info='Success', result=0) elif self.auth_type == 'group': session_result = await session.execute( - select(AuthGroup).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == self.auth_id). where(AuthGroup.auth_node == self.auth_node) ) @@ -265,23 +328,33 @@ async def delete(self) -> Result.IntResult: return result @classmethod - async def list(cls, auth_type: str, auth_id: int) -> Result.ListResult: + async def list(cls, auth_type: str, auth_id: int, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: if auth_type == 'user': session_result = await session.execute( - select(AuthUser.auth_node, AuthUser.allow_tag, AuthUser.deny_tag).join(User). - where(AuthUser.user_id == User.id). + select(AuthUser.auth_node, AuthUser.allow_tag, AuthUser.deny_tag). + join(Friends).join(User). + where(AuthUser.user_id == Friends.id). + where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(User.qq == auth_id) ) auth_node_list = [(x[0], x[1], x[2]) for x in session_result.all()] result = Result.ListResult(error=False, info='Success', result=auth_node_list) elif auth_type == 'group': session_result = await session.execute( - select(AuthGroup.auth_node, AuthGroup.allow_tag, AuthGroup.deny_tag).join(Group). - where(AuthGroup.group_id == Group.id). + select(AuthGroup.auth_node, AuthGroup.allow_tag, AuthGroup.deny_tag). + join(BotGroup).join(Group). + where(AuthGroup.group_id == BotGroup.id). + where(BotGroup.group_id == Group.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). where(Group.group_id == auth_id) ) auth_node_list = [(x[0], x[1], x[2]) for x in session_result.all()] diff --git a/omega_miya/utils/Omega_Base/model/bilidynamic.py b/omega_miya/utils/Omega_Base/model/bilidynamic.py index 6813d75b..a32439d8 100644 --- a/omega_miya/utils/Omega_Base/model/bilidynamic.py +++ b/omega_miya/utils/Omega_Base/model/bilidynamic.py @@ -64,3 +64,35 @@ async def add(self, dynamic_type: int, content: str) -> Result.IntResult: await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result + + @classmethod + async def list_all_dynamic(cls) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Bilidynamic.dynamic_id).order_by(Bilidynamic.dynamic_id) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_dynamic_by_uid(cls, uid: int) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Bilidynamic.dynamic_id). + where(Bilidynamic.uid == uid). + order_by(Bilidynamic.dynamic_id) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result diff --git a/omega_miya/utils/Omega_Base/model/bot_group.py b/omega_miya/utils/Omega_Base/model/bot_group.py new file mode 100644 index 00000000..21d3916c --- /dev/null +++ b/omega_miya/utils/Omega_Base/model/bot_group.py @@ -0,0 +1,1070 @@ +from omega_miya.utils.Omega_Base.database import NBdb +from omega_miya.utils.Omega_Base.class_result import Result +from omega_miya.utils.Omega_Base.tables import \ + User, Group, BotGroup, UserGroup, Vacation, Skill, UserSkill, \ + Subscription, GroupSub, GroupSetting, EmailBox, GroupEmailBox +from .user import DBUser +from .skill import DBSkill +from .group import DBGroup +from .bot_self import DBBot +from .subscription import DBSubscription +from .mail import DBEmailBox +from datetime import datetime +from sqlalchemy.future import select +from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound + + +class DBBotGroup(DBGroup): + def __init__(self, group_id: int, self_bot: DBBot): + super().__init__(group_id) + self.self_bot = self_bot + + @classmethod + async def list_exist_bot_groups(cls, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(Group.group_id). + join(BotGroup). + where(Group.id == BotGroup.group_id). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + exist_groupss = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=exist_groupss) + except NoResultFound: + result = Result.ListResult(error=True, info='NoResultFound', result=[]) + except MultipleResultsFound: + result = Result.ListResult(error=True, info='MultipleResultsFound', result=[]) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_exist_bot_groups_by_notice_permissions( + cls, notice_permissions: int, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(Group.group_id). + join(BotGroup). + where(Group.id == BotGroup.group_id). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(BotGroup.notice_permissions == notice_permissions) + ) + exist_friends = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=exist_friends) + except NoResultFound: + result = Result.ListResult(error=True, info='NoResultFound', result=[]) + except MultipleResultsFound: + result = Result.ListResult(error=True, info='MultipleResultsFound', result=[]) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_exist_bot_groups_by_command_permissions( + cls, command_permissions: int, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(Group.group_id). + join(BotGroup). + where(Group.id == BotGroup.group_id). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(BotGroup.command_permissions == command_permissions) + ) + exist_friends = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=exist_friends) + except NoResultFound: + result = Result.ListResult(error=True, info='NoResultFound', result=[]) + except MultipleResultsFound: + result = Result.ListResult(error=True, info='MultipleResultsFound', result=[]) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_exist_bot_groups_by_permission_level( + cls, permission_level: int, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(Group.group_id). + join(BotGroup). + where(Group.id == BotGroup.group_id). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(BotGroup.permission_level >= permission_level) + ) + exist_friends = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=exist_friends) + except NoResultFound: + result = Result.ListResult(error=True, info='NoResultFound', result=[]) + except MultipleResultsFound: + result = Result.ListResult(error=True, info='MultipleResultsFound', result=[]) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + async def bot_group_id(self) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(BotGroup.id). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(BotGroup.group_id == group_id_result.result) + ) + bot_group_table_id = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=bot_group_table_id) + except NoResultFound: + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def exist(self) -> bool: + result = await self.bot_group_id() + return result.success() + + async def memo(self) -> Result.TextResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.TextResult(error=True, info='Group not exist', result='') + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.TextResult(error=True, info='Bot not exist', result='') + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(BotGroup.group_memo). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(BotGroup.group_id == group_id_result.result) + ) + group_memo = session_result.scalar_one() + result = Result.TextResult(error=False, info='Success', result=group_memo) + except NoResultFound: + result = Result.TextResult(error=True, info='NoResultFound', result='') + except MultipleResultsFound: + result = Result.TextResult(error=True, info='MultipleResultsFound', result='') + except Exception as e: + result = Result.TextResult(error=True, info=repr(e), result='') + return result + + async def set_bot_group(self, group_memo: str = None) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + # 处理群备注过长 + if not group_memo: + pass + elif len(group_memo) > 64: + group_memo = group_memo[:63] + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(BotGroup). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + exist_group = session_result.scalar_one() + if group_memo: + exist_group.group_memo = group_memo + exist_group.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + new_group = BotGroup(group_id=group_id_result.result, bot_self_id=self_bot_id_result.result, + notice_permissions=0, command_permissions=0, permission_level=0, + group_memo=group_memo, created_at=datetime.now()) + session.add(new_group) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def del_bot_group(self) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + # 删除群组表中该群组 + session_result = await session.execute( + select(BotGroup). + where(BotGroup.id == bot_group_id_result.result) + ) + exist_group = session_result.scalar_one() + await session.delete(exist_group) + await session.commit() + result = Result.IntResult(error=False, info='Success Delete', result=0) + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def member_list(self) -> Result.TupleListResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TupleListResult(error=True, info='BotGroup not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(User.qq, UserGroup.user_group_nickname). + join(UserGroup). + where(User.id == UserGroup.user_id). + where(UserGroup.group_id == bot_group_id_result.result) + ) + res = [(x[0], x[1]) for x in session_result.all()] + result = Result.TupleListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TupleListResult(error=True, info=repr(e), result=[]) + return result + + async def member_add(self, user: DBUser, user_group_nickname: str) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + user_id_result = await user.id() + if user_id_result.error: + return Result.IntResult(error=True, info='User not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + # 查询成员-群组表中用户-群关系 + try: + # 用户-群关系已存在, 更新用户群昵称 + session_result = await session.execute( + select(UserGroup). + where(UserGroup.user_id == user_id_result.result). + where(UserGroup.group_id == bot_group_id_result.result) + ) + exist_user = session_result.scalar_one() + exist_user.user_group_nickname = user_group_nickname + exist_user.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + # 不存在关系则添加新成员 + new_user = UserGroup(user_id=user_id_result.result, group_id=bot_group_id_result.result, + user_group_nickname=user_group_nickname, created_at=datetime.now()) + session.add(new_user) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def member_del(self, user: DBUser) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + user_id_result = await user.id() + if user_id_result.error: + return Result.IntResult(error=True, info='User not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(UserGroup). + where(UserGroup.user_id == user_id_result.result). + where(UserGroup.group_id == bot_group_id_result.result) + ) + exist_user = session_result.scalar_one() + await session.delete(exist_user) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def member_clear(self) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(UserGroup).where(UserGroup.group_id == bot_group_id_result.result) + ) + for exist_user in session_result.scalars().all(): + await session.delete(exist_user) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def permission_reset(self) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(BotGroup). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + exist_group = session_result.scalar_one() + exist_group.notice_permissions = 0 + exist_group.command_permissions = 0 + exist_group.permission_level = 0 + exist_group.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + await session.commit() + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def permission_set(self, notice: int = 0, command: int = 0, level: int = 0) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(BotGroup). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + exist_group = session_result.scalar_one() + exist_group.notice_permissions = notice + exist_group.command_permissions = command + exist_group.permission_level = level + exist_group.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + await session.commit() + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def permission_info(self) -> Result.IntTupleResult: + """ + :return: Result: Tuple[Notice_permission, Command_permission, Permission_level] + """ + group_id_result = await self.id() + if group_id_result.error: + return Result.IntTupleResult(error=True, info='Group not exist', result=(-1, -1, -1)) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntTupleResult(error=True, info='Bot not exist', result=(-1, -1, -1)) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(BotGroup.notice_permissions, BotGroup.command_permissions, BotGroup.permission_level). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + notice, command, level = session_result.one() + result = Result.IntTupleResult(error=False, info='Success', result=(notice, command, level)) + except NoResultFound: + result = Result.IntTupleResult(error=True, info='NoResultFound', result=(-1, -1, -1)) + except MultipleResultsFound: + result = Result.IntTupleResult(error=True, info='MultipleResultsFound', result=(-1, -1, -1)) + except Exception as e: + result = Result.IntTupleResult(error=True, info=repr(e), result=(-1, -1, -1)) + return result + + async def permission_notice(self) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(BotGroup.notice_permissions). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + res = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def permission_command(self) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(BotGroup.command_permissions). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + res = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def permission_level(self) -> Result.IntResult: + group_id_result = await self.id() + if group_id_result.error: + return Result.IntResult(error=True, info='Group not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(BotGroup.permission_level). + where(BotGroup.group_id == group_id_result.result). + where(BotGroup.bot_self_id == self_bot_id_result.result) + ) + res = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def idle_member_list(self) -> Result.ListResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.ListResult(error=True, info='BotGroup not exist', result=[]) + + res = [] + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + # 查询该群组中所有没有假的人 + session_result = await session.execute( + select(User.id, UserGroup.user_group_nickname). + join(Vacation).join(UserGroup). + where(User.id == Vacation.user_id). + where(User.id == UserGroup.user_id). + where(Vacation.status == 0). + where(UserGroup.group_id == bot_group_id_result.result) + ) + user_res = [(x[0], x[1]) for x in session_result.all()] + # 查对应每个人的技能 + for user_id, nickname in user_res: + session_result = await session.execute( + select(Skill.name).join(UserSkill).join(User). + where(Skill.id == UserSkill.skill_id). + where(UserSkill.user_id == User.id). + where(User.id == user_id) + ) + user_skill_res = [x for x in session_result.scalars().all()] + if user_skill_res: + user_skill_text = '/'.join(user_skill_res) + else: + user_skill_text = '暂无技能' + res.append((nickname, user_skill_text)) + result = Result.ListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + async def idle_skill_list(self, skill: DBSkill) -> Result.TupleListResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TupleListResult(error=True, info='BotGroup not exist', result=[]) + + skill_id_result = await skill.id() + if skill_id_result.error: + return Result.TupleListResult(error=True, info='Skill not exist', result=[]) + + res = [] + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + # 查询这这个技能有那些人会 + session_result = await session.execute( + select(User.id, UserGroup.user_group_nickname). + join(UserSkill).join(UserGroup). + where(User.id == UserSkill.user_id). + where(User.id == UserGroup.user_id). + where(UserSkill.skill_id == skill_id_result.result). + where(UserGroup.group_id == bot_group_id_result.result) + ) + user_res = [(x[0], x[1]) for x in session_result.all()] + # 查这个人是不是空闲 + for user_id, nickname in user_res: + session_result = await session.execute( + select(Vacation.status).where(Vacation.user_id == user_id) + ) + # 如果空闲则把这个人昵称放进结果列表里面 + if session_result.scalar_one() == 0: + res.append(nickname) + result = Result.TupleListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TupleListResult(error=True, info=repr(e), result=[]) + return result + + async def vacation_member_list(self) -> Result.TupleListResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TupleListResult(error=True, info='BotGroup not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + # 查询所有没有假的人 + session_result = await session.execute( + select(UserGroup.user_group_nickname, Vacation.stop_at). + select_from(UserGroup).join(User). + where(UserGroup.user_id == User.id). + where(User.id == Vacation.user_id). + where(Vacation.status == 1). + where(UserGroup.group_id == bot_group_id_result.result) + ) + res = [(x[0], x[1]) for x in session_result.all()] + result = Result.TupleListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TupleListResult(error=True, info=repr(e), result=[]) + return result + + async def init_member_status(self) -> Result.IntResult: + member_list_res = await self.member_list() + for user_qq, nickname in member_list_res.result: + user = DBUser(user_id=user_qq) + user_status_res = await user.status() + if user_status_res.error: + await user.status_set(status=0) + return Result.IntResult(error=False, info='ignore', result=0) + + async def subscription_list(self) -> Result.TupleListResult: + """ + :return: Result: List[Tuple[sub_type, sub_id, up_name]] + """ + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TupleListResult(error=True, info='BotGroup not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Subscription.sub_type, Subscription.sub_id, Subscription.up_name). + join(GroupSub). + where(Subscription.id == GroupSub.sub_id). + where(GroupSub.group_id == bot_group_id_result.result) + ) + res = [(x[0], x[1], x[2]) for x in session_result.all()] + result = Result.TupleListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TupleListResult(error=True, info=repr(e), result=[]) + return result + + async def subscription_list_by_type(self, sub_type: int) -> Result.TupleListResult: + """ + :param sub_type: 订阅类型 + :return: Result: List[Tuple[sub_id, up_name]] + """ + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TupleListResult(error=True, info='BotGroup not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Subscription.sub_id, Subscription.up_name). + join(GroupSub). + where(Subscription.sub_type == sub_type). + where(Subscription.id == GroupSub.sub_id). + where(GroupSub.group_id == bot_group_id_result.result) + ) + res = [(x[0], x[1]) for x in session_result.all()] + result = Result.TupleListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TupleListResult(error=True, info=repr(e), result=[]) + return result + + async def subscription_add(self, sub: DBSubscription, group_sub_info: str = None) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + sub_id_result = await sub.id() + if sub_id_result.error: + return Result.IntResult(error=True, info='Subscription not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(GroupSub). + where(GroupSub.group_id == bot_group_id_result.result). + where(GroupSub.sub_id == sub_id_result.result) + ) + # 订阅关系已存在, 更新信息 + exist_subscription = session_result.scalar_one() + exist_subscription.group_sub_info = group_sub_info + exist_subscription.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + subscription = GroupSub(sub_id=sub_id_result.result, group_id=bot_group_id_result.result, + group_sub_info=group_sub_info, created_at=datetime.now()) + session.add(subscription) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def subscription_del(self, sub: DBSubscription) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + sub_id_result = await sub.id() + if sub_id_result.error: + return Result.IntResult(error=True, info='Subscription not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(GroupSub). + where(GroupSub.group_id == bot_group_id_result.result). + where(GroupSub.sub_id == sub_id_result.result) + ) + exist_subscription = session_result.scalar_one() + await session.delete(exist_subscription) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def subscription_clear(self) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(GroupSub).where(GroupSub.group_id == bot_group_id_result.result) + ) + for exist_group_sub in session_result.scalars().all(): + await session.delete(exist_group_sub) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def subscription_clear_by_type(self, sub_type: int) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(GroupSub).join(Subscription). + where(GroupSub.sub_id == Subscription.id). + where(Subscription.sub_type == sub_type). + where(GroupSub.group_id == bot_group_id_result.result) + ) + for exist_group_sub in session_result.scalars().all(): + await session.delete(exist_group_sub) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def mailbox_list(self) -> Result.ListResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.ListResult(error=True, info='BotGroup not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(EmailBox.address). + join(GroupEmailBox). + where(EmailBox.id == GroupEmailBox.email_box_id). + where(GroupEmailBox.group_id == bot_group_id_result.result) + ) + res = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + async def mailbox_add(self, mailbox: DBEmailBox, mailbox_info: str = None) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + mailbox_id_result = await mailbox.id() + if mailbox_id_result.error: + return Result.IntResult(error=True, info='Mailbox not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(GroupEmailBox). + where(GroupEmailBox.group_id == bot_group_id_result.result). + where(GroupEmailBox.email_box_id == mailbox_id_result.result) + ) + # 群邮箱已存在, 更新信息 + exist_mailbox = session_result.scalar_one() + exist_mailbox.box_info = mailbox_info + exist_mailbox.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + new_mailbox = GroupEmailBox(email_box_id=mailbox_id_result.result, + group_id=bot_group_id_result.result, + box_info=mailbox_info, created_at=datetime.now()) + session.add(new_mailbox) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def mailbox_del(self, mailbox: DBEmailBox) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + mailbox_id_result = await mailbox.id() + if mailbox_id_result.error: + return Result.IntResult(error=True, info='Mailbox not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(GroupEmailBox). + where(GroupEmailBox.group_id == bot_group_id_result.result). + where(GroupEmailBox.email_box_id == mailbox_id_result.result) + ) + exist_mailbox = session_result.scalar_one() + await session.delete(exist_mailbox) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def mailbox_clear(self) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(GroupEmailBox).where(GroupEmailBox.group_id == bot_group_id_result.result) + ) + for exist_mailbox in session_result.scalars().all(): + await session.delete(exist_mailbox) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def setting_list(self) -> Result.TupleListResult: + """ + :return: Result: List[Tuple[setting_name, main_config, secondary_config, extra_config]] + """ + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TupleListResult(error=True, info='BotGroup not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(GroupSetting.setting_name, GroupSetting.main_config, + GroupSetting.secondary_config, GroupSetting.extra_config). + where(GroupSetting.group_id == bot_group_id_result.result) + ) + res = [(x[0], x[1], x[2], x[3]) for x in session_result.all()] + result = Result.TupleListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TupleListResult(error=True, info=repr(e), result=[]) + return result + + async def setting_get(self, setting_name: str) -> Result.TextTupleResult: + """ + :param setting_name: 配置名称 + :return: Result: Tuple[main_config, secondary_config, extra_config] + """ + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.TextTupleResult(error=True, info='BotGroup not exist', result=('', '', '')) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(GroupSetting.main_config, GroupSetting.secondary_config, GroupSetting.extra_config). + where(GroupSetting.setting_name == setting_name). + where(GroupSetting.group_id == bot_group_id_result.result) + ) + main, second, extra = session_result.one() + result = Result.TextTupleResult(error=False, info='Success', result=(main, second, extra)) + except NoResultFound: + result = Result.TextTupleResult(error=True, info='NoResultFound', result=('', '', '')) + except MultipleResultsFound: + result = Result.TextTupleResult(error=True, info='MultipleResultsFound', result=('', '', '')) + except Exception as e: + result = Result.TextTupleResult(error=True, info=repr(e), result=('', '', '')) + return result + + async def setting_set( + self, + setting_name: str, + main_config: str, + *, + secondary_config: str = 'None', + extra_config: str = 'None', + setting_info: str = 'None') -> Result.IntResult: + + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(GroupSetting). + where(GroupSetting.setting_name == setting_name). + where(GroupSetting.group_id == bot_group_id_result.result) + ) + # 已存在, 更新信息 + exist_setting = session_result.scalar_one() + exist_setting.main_config = main_config + exist_setting.secondary_config = secondary_config + exist_setting.extra_config = extra_config + exist_setting.setting_info = setting_info + exist_setting.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + new_setting = GroupSetting(group_id=bot_group_id_result.result, setting_name=setting_name, + main_config=main_config, secondary_config=secondary_config, + extra_config=extra_config, setting_info=setting_info, + created_at=datetime.now()) + session.add(new_setting) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def setting_del(self, setting_name: str) -> Result.IntResult: + bot_group_id_result = await self.bot_group_id() + if bot_group_id_result.error: + return Result.IntResult(error=True, info='BotGroup not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(GroupSetting). + where(GroupSetting.setting_name == setting_name). + where(GroupSetting.group_id == bot_group_id_result.result) + ) + exist_setting = session_result.scalar_one() + await session.delete(exist_setting) + await session.commit() + result = Result.IntResult(error=False, info='Success', result=0) + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result diff --git a/omega_miya/utils/Omega_Base/model/pixivtag.py b/omega_miya/utils/Omega_Base/model/bot_self.py similarity index 54% rename from omega_miya/utils/Omega_Base/model/pixivtag.py rename to omega_miya/utils/Omega_Base/model/bot_self.py index ed045840..9d41aa39 100644 --- a/omega_miya/utils/Omega_Base/model/pixivtag.py +++ b/omega_miya/utils/Omega_Base/model/bot_self.py @@ -1,14 +1,24 @@ +""" +@Author : Ailitonia +@Date : 2021/05/23 19:32 +@FileName : bot_self.py +@Project : nonebot2_miya +@Description : BotSelf Table Model +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result -from omega_miya.utils.Omega_Base.tables import PixivTag, Pixiv, PixivT2I +from omega_miya.utils.Omega_Base.tables import BotSelf from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound -class DBPixivtag(object): - def __init__(self, tagname: str): - self.tagname = tagname +class DBBot(object): + def __init__(self, self_qq: int): + self.self_qq = self_qq async def id(self) -> Result.IntResult: async_session = NBdb().get_async_session() @@ -16,10 +26,10 @@ async def id(self) -> Result.IntResult: async with session.begin(): try: session_result = await session.execute( - select(PixivTag.id).where(PixivTag.tagname == self.tagname) + select(BotSelf.id).where(BotSelf.self_qq == self.self_qq) ) - pixivtag_table_id = session_result.scalar_one() - result = Result.IntResult(error=False, info='Success', result=pixivtag_table_id) + bot_table_id = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=bot_table_id) except NoResultFound: result = Result.IntResult(error=True, info='NoResultFound', result=-1) except MultipleResultsFound: @@ -32,20 +42,27 @@ async def exist(self) -> bool: result = await self.id() return result.success() - async def add(self) -> Result.IntResult: + async def upgrade(self, status: int = 0, info: str = None) -> Result.IntResult: async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): try: + # 已存在则更新表中已有信息 session_result = await session.execute( - select(PixivTag).where(PixivTag.tagname == self.tagname) + select(BotSelf).where(BotSelf.self_qq == self.self_qq) ) - exist_pixivtag = session_result.scalar_one() - result = Result.IntResult(error=False, info='pixivtag exist', result=0) + exist_bot = session_result.scalar_one() + exist_bot.status = status + if info: + exist_bot.info = info + exist_bot.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: - new_tag = PixivTag(tagname=self.tagname, created_at=datetime.now()) - session.add(new_tag) + # 不存在则在表中添加新信息 + new_bot = BotSelf(self_qq=self.self_qq, status=status, info=info, + created_at=datetime.now()) + session.add(new_bot) result = Result.IntResult(error=False, info='Success added', result=0) await session.commit() except MultipleResultsFound: @@ -55,21 +72,3 @@ async def add(self) -> Result.IntResult: await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result - - async def list_illust(self, nsfw_tag: int) -> Result.ListResult: - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(Pixiv.pid).join(PixivT2I).join(PixivTag). - where(Pixiv.id == PixivT2I.illust_id). - where(PixivT2I.tag_id == PixivTag.id). - where(Pixiv.nsfw_tag == nsfw_tag). - where(PixivTag.tagname.ilike(f'%{self.tagname}%')) - ) - tag_pid_list = [x for x in session_result.scalars().all()] - result = Result.ListResult(error=False, info='Success', result=tag_pid_list) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result diff --git a/omega_miya/utils/Omega_Base/model/cooldown.py b/omega_miya/utils/Omega_Base/model/cooldown.py index 3157b2dd..687210f2 100644 --- a/omega_miya/utils/Omega_Base/model/cooldown.py +++ b/omega_miya/utils/Omega_Base/model/cooldown.py @@ -9,6 +9,11 @@ class DBCoolDownEvent(object): @classmethod async def add_global_cool_down_event(cls, stop_at: datetime, description: str = None) -> Result.IntResult: + """ + :return: + result = 0: Success + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -39,6 +44,13 @@ async def add_global_cool_down_event(cls, stop_at: datetime, description: str = @classmethod async def check_global_cool_down_event(cls) -> Result.IntResult: + """ + :return: + result = 2: Success with CoolDown Event expired + result = 1: Success with CoolDown Event exist + result = 0: Success with CoolDown Event not found + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): @@ -49,7 +61,10 @@ async def check_global_cool_down_event(cls) -> Result.IntResult: ) event = session_result.scalar_one() stop_at = event.stop_at - result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) + if datetime.now() > stop_at: + result = Result.IntResult(error=False, info='Success, CoolDown expired', result=2) + else: + result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) except NoResultFound: result = Result.IntResult(error=False, info='NoResultFound', result=0) except MultipleResultsFound: @@ -61,6 +76,11 @@ async def check_global_cool_down_event(cls) -> Result.IntResult: @classmethod async def add_plugin_cool_down_event( cls, plugin: str, stop_at: datetime, description: str = None) -> Result.IntResult: + """ + :return: + result = 0: Success + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -93,6 +113,13 @@ async def add_plugin_cool_down_event( @classmethod async def check_plugin_cool_down_event(cls, plugin: str) -> Result.IntResult: + """ + :return: + result = 2: Success with CoolDown Event expired + result = 1: Success with CoolDown Event exist + result = 0: Success with CoolDown Event not found + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): @@ -104,7 +131,10 @@ async def check_plugin_cool_down_event(cls, plugin: str) -> Result.IntResult: ) event = session_result.scalar_one() stop_at = event.stop_at - result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) + if datetime.now() > stop_at: + result = Result.IntResult(error=False, info='Success, CoolDown expired', result=2) + else: + result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) except NoResultFound: result = Result.IntResult(error=False, info='NoResultFound', result=0) except MultipleResultsFound: @@ -116,6 +146,11 @@ async def check_plugin_cool_down_event(cls, plugin: str) -> Result.IntResult: @classmethod async def add_group_cool_down_event( cls, plugin: str, group_id: int, stop_at: datetime, description: str = None) -> Result.IntResult: + """ + :return: + result = 0: Success + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -149,6 +184,13 @@ async def add_group_cool_down_event( @classmethod async def check_group_cool_down_event(cls, plugin: str, group_id: int) -> Result.IntResult: + """ + :return: + result = 2: Success with CoolDown Event expired + result = 1: Success with CoolDown Event exist + result = 0: Success with CoolDown Event not found + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): @@ -161,7 +203,10 @@ async def check_group_cool_down_event(cls, plugin: str, group_id: int) -> Result ) event = session_result.scalar_one() stop_at = event.stop_at - result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) + if datetime.now() > stop_at: + result = Result.IntResult(error=False, info='Success, CoolDown expired', result=2) + else: + result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) except NoResultFound: result = Result.IntResult(error=False, info='NoResultFound', result=0) except MultipleResultsFound: @@ -173,6 +218,11 @@ async def check_group_cool_down_event(cls, plugin: str, group_id: int) -> Result @classmethod async def add_user_cool_down_event( cls, plugin: str, user_id: int, stop_at: datetime, description: str = None) -> Result.IntResult: + """ + :return: + result = 0: Success + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -206,6 +256,13 @@ async def add_user_cool_down_event( @classmethod async def check_user_cool_down_event(cls, plugin: str, user_id: int) -> Result.IntResult: + """ + :return: + result = 2: Success with CoolDown Event expired + result = 1: Success with CoolDown Event exist + result = 0: Success with CoolDown Event not found + result = -1: Error + """ async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): @@ -218,7 +275,10 @@ async def check_user_cool_down_event(cls, plugin: str, user_id: int) -> Result.I ) event = session_result.scalar_one() stop_at = event.stop_at - result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) + if datetime.now() > stop_at: + result = Result.IntResult(error=False, info='Success, CoolDown expired', result=2) + else: + result = Result.IntResult(error=False, info=f'CoolDown until: {stop_at}', result=1) except NoResultFound: result = Result.IntResult(error=False, info='NoResultFound', result=0) except MultipleResultsFound: diff --git a/omega_miya/utils/Omega_Base/model/friend.py b/omega_miya/utils/Omega_Base/model/friend.py index d0bed521..f1815140 100644 --- a/omega_miya/utils/Omega_Base/model/friend.py +++ b/omega_miya/utils/Omega_Base/model/friend.py @@ -1,17 +1,26 @@ +from typing import Optional +from datetime import datetime from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result from omega_miya.utils.Omega_Base.tables import Friends, User, Subscription, UserSub from .user import DBUser +from .bot_self import DBBot from .subscription import DBSubscription -from typing import Optional -from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound class DBFriend(DBUser): + def __init__(self, user_id: int, self_bot: DBBot): + super().__init__(user_id) + self.self_bot = self_bot + @classmethod - async def list_exist_friends(cls) -> Result.ListResult: + async def list_exist_friends(cls, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -19,7 +28,8 @@ async def list_exist_friends(cls) -> Result.ListResult: session_result = await session.execute( select(User.qq). join(Friends). - where(User.id == Friends.user_id) + where(User.id == Friends.user_id). + where(Friends.bot_self_id == self_bot_id_result.result) ) exist_friends = [x for x in session_result.scalars().all()] result = Result.ListResult(error=False, info='Success', result=exist_friends) @@ -32,7 +42,12 @@ async def list_exist_friends(cls) -> Result.ListResult: return result @classmethod - async def list_exist_friends_by_private_permission(cls, private_permission: int) -> Result.ListResult: + async def list_exist_friends_by_private_permission( + cls, private_permission: int, self_bot: DBBot) -> Result.ListResult: + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -41,6 +56,7 @@ async def list_exist_friends_by_private_permission(cls, private_permission: int) select(User.qq). join(Friends). where(User.id == Friends.user_id). + where(Friends.bot_self_id == self_bot_id_result.result). where(Friends.private_permissions == private_permission) ) exist_friends = [x for x in session_result.scalars().all()] @@ -53,10 +69,14 @@ async def list_exist_friends_by_private_permission(cls, private_permission: int) result = Result.ListResult(error=True, info=repr(e), result=[]) return result - async def exist(self) -> bool: + async def friend_id(self) -> Result.IntResult: user_id_result = await self.id() if user_id_result.error: - return False + return Result.IntResult(error=True, info='User not exist', result=-1) + + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) async_session = NBdb().get_async_session() async with async_session() as session: @@ -64,14 +84,22 @@ async def exist(self) -> bool: async with session.begin(): session_result = await session.execute( select(Friends.id). - join(User). - where(Friends.user_id == User.id). + where(Friends.bot_self_id == self_bot_id_result.result). where(Friends.user_id == user_id_result.result) ) - exist_friend = session_result.scalar_one() - return True - except Exception: - return False + friend_table_id = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=friend_table_id) + except NoResultFound: + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def exist(self) -> bool: + result = await self.friend_id() + return result.success() async def set_friend( self, nickname: str, remark: Optional[str] = None, private_permissions: Optional[int] = None @@ -81,13 +109,19 @@ async def set_friend( if user_id_result.error: return Result.IntResult(error=True, info='User not exist', result=-1) + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): try: session_result = await session.execute( - select(Friends).where(Friends.user_id == user_id_result.result) + select(Friends). + where(Friends.user_id == user_id_result.result). + where(Friends.bot_self_id == self_bot_id_result.result) ) exist_friend = session_result.scalar_one() exist_friend.nickname = nickname @@ -98,10 +132,12 @@ async def set_friend( result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: if private_permissions: - new_friend = Friends(user_id=user_id_result.result, nickname=nickname, remark=remark, + new_friend = Friends(user_id=user_id_result.result, bot_self_id=self_bot_id_result.result, + nickname=nickname, remark=remark, private_permissions=private_permissions, created_at=datetime.now()) else: - new_friend = Friends(user_id=user_id_result.result, nickname=nickname, remark=remark, + new_friend = Friends(user_id=user_id_result.result, bot_self_id=self_bot_id_result.result, + nickname=nickname, remark=remark, private_permissions=0, created_at=datetime.now()) session.add(new_friend) result = Result.IntResult(error=False, info='Success added', result=0) @@ -115,24 +151,18 @@ async def set_friend( return result async def del_friend(self) -> Result.IntResult: - user_id_result = await self.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.IntResult(error=True, info='Friend not exist', result=-1) async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): - # 清空订阅 + # 删除好友表中该好友信息 session_result = await session.execute( - select(UserSub).where(UserSub.user_id == user_id_result.result) - ) - for exist_user_sub in session_result.scalars().all(): - await session.delete(exist_user_sub) - - # 删除好友表中该群组 - session_result = await session.execute( - select(Friends).where(Friends.user_id == user_id_result.result) + select(Friends). + where(Friends.id == friend_id_result.result) ) exist_friend = session_result.scalar_one() await session.delete(exist_friend) @@ -154,12 +184,18 @@ async def set_private_permission(self, private_permissions: int) -> Result.IntRe if user_id_result.error: return Result.IntResult(error=True, info='User not exist', result=-1) + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): session_result = await session.execute( - select(Friends).where(Friends.user_id == user_id_result.result) + select(Friends). + where(Friends.user_id == user_id_result.result). + where(Friends.bot_self_id == self_bot_id_result.result) ) exist_friend = session_result.scalar_one() exist_friend.private_permissions = private_permissions @@ -182,12 +218,18 @@ async def get_private_permission(self) -> Result.IntResult: if user_id_result.error: return Result.IntResult(error=True, info='User not exist', result=-1) + self_bot_id_result = await self.self_bot.id() + if self_bot_id_result.error: + return Result.IntResult(error=True, info='Bot not exist', result=-1) + async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): session_result = await session.execute( - select(Friends.private_permissions).where(Friends.user_id == user_id_result.result) + select(Friends.private_permissions). + where(Friends.user_id == user_id_result.result). + where(Friends.bot_self_id == self_bot_id_result.result) ) private_permissions = session_result.scalar_one() result = Result.IntResult(error=False, info='Success', result=private_permissions) @@ -199,17 +241,13 @@ async def get_private_permission(self) -> Result.IntResult: result = Result.IntResult(error=True, info=repr(e), result=-1) return result - async def subscription_list(self) -> Result.ListResult: + async def subscription_list(self) -> Result.TupleListResult: """ :return: Result: List[Tuple[sub_type, sub_id, up_name]] """ - friend_check = await self.exist() - if not friend_check: - return Result.ListResult(error=True, info='Not friend', result=[]) - - user_id_result = await self.id() - if user_id_result.error: - return Result.ListResult(error=True, info='User not exist', result=[]) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.TupleListResult(error=True, info='Friend not exist', result=[]) async_session = NBdb().get_async_session() async with async_session() as session: @@ -219,26 +257,22 @@ async def subscription_list(self) -> Result.ListResult: select(Subscription.sub_type, Subscription.sub_id, Subscription.up_name). join(UserSub). where(Subscription.id == UserSub.sub_id). - where(UserSub.user_id == user_id_result.result) + where(UserSub.user_id == friend_id_result.result) ) res = [(x[0], x[1], x[2]) for x in session_result.all()] - result = Result.ListResult(error=False, info='Success', result=res) + result = Result.TupleListResult(error=False, info='Success', result=res) except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) + result = Result.TupleListResult(error=True, info=repr(e), result=[]) return result - async def subscription_list_by_type(self, sub_type: int) -> Result.ListResult: + async def subscription_list_by_type(self, sub_type: int) -> Result.TupleListResult: """ :param sub_type: 订阅类型 :return: Result: List[Tuple[sub_id, up_name]] """ - friend_check = await self.exist() - if not friend_check: - return Result.ListResult(error=True, info='Not friend', result=[]) - - user_id_result = await self.id() - if user_id_result.error: - return Result.ListResult(error=True, info='User not exist', result=[]) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.TupleListResult(error=True, info='Friend not exist', result=[]) async_session = NBdb().get_async_session() async with async_session() as session: @@ -249,22 +283,18 @@ async def subscription_list_by_type(self, sub_type: int) -> Result.ListResult: join(UserSub). where(Subscription.sub_type == sub_type). where(Subscription.id == UserSub.sub_id). - where(UserSub.user_id == user_id_result.result) + where(UserSub.user_id == friend_id_result.result) ) res = [(x[0], x[1]) for x in session_result.all()] - result = Result.ListResult(error=False, info='Success', result=res) + result = Result.TupleListResult(error=False, info='Success', result=res) except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) + result = Result.TupleListResult(error=True, info=repr(e), result=[]) return result async def subscription_add(self, sub: DBSubscription, user_sub_info: str = None) -> Result.IntResult: - friend_check = await self.exist() - if not friend_check: - return Result.IntResult(error=True, info='Not friend', result=-1) - - user_id_result = await self.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.IntResult(error=True, info='Friend not exist', result=-1) sub_id_result = await sub.id() if sub_id_result.error: @@ -277,7 +307,7 @@ async def subscription_add(self, sub: DBSubscription, user_sub_info: str = None) try: session_result = await session.execute( select(UserSub). - where(UserSub.user_id == user_id_result.result). + where(UserSub.user_id == friend_id_result.result). where(UserSub.sub_id == sub_id_result.result) ) # 订阅关系已存在, 更新信息 @@ -286,7 +316,7 @@ async def subscription_add(self, sub: DBSubscription, user_sub_info: str = None) exist_subscription.updated_at = datetime.now() result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: - subscription = UserSub(sub_id=sub_id_result.result, user_id=user_id_result.result, + subscription = UserSub(sub_id=sub_id_result.result, user_id=friend_id_result.result, user_sub_info=user_sub_info, created_at=datetime.now()) session.add(subscription) result = Result.IntResult(error=False, info='Success added', result=0) @@ -300,9 +330,9 @@ async def subscription_add(self, sub: DBSubscription, user_sub_info: str = None) return result async def subscription_del(self, sub: DBSubscription) -> Result.IntResult: - user_id_result = await self.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.IntResult(error=True, info='Friend not exist', result=-1) sub_id_result = await sub.id() if sub_id_result.error: @@ -314,7 +344,7 @@ async def subscription_del(self, sub: DBSubscription) -> Result.IntResult: async with session.begin(): session_result = await session.execute( select(UserSub). - where(UserSub.user_id == user_id_result.result). + where(UserSub.user_id == friend_id_result.result). where(UserSub.sub_id == sub_id_result.result) ) exist_subscription = session_result.scalar_one() @@ -333,16 +363,16 @@ async def subscription_del(self, sub: DBSubscription) -> Result.IntResult: return result async def subscription_clear(self) -> Result.IntResult: - user_id_result = await self.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.IntResult(error=True, info='Friend not exist', result=-1) async_session = NBdb().get_async_session() async with async_session() as session: try: async with session.begin(): session_result = await session.execute( - select(UserSub).where(UserSub.user_id == user_id_result.result) + select(UserSub).where(UserSub.user_id == friend_id_result.result) ) for exist_user_sub in session_result.scalars().all(): await session.delete(exist_user_sub) @@ -354,9 +384,9 @@ async def subscription_clear(self) -> Result.IntResult: return result async def subscription_clear_by_type(self, sub_type: int) -> Result.IntResult: - user_id_result = await self.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) + friend_id_result = await self.friend_id() + if friend_id_result.error: + return Result.IntResult(error=True, info='Friend not exist', result=-1) async_session = NBdb().get_async_session() async with async_session() as session: @@ -366,7 +396,7 @@ async def subscription_clear_by_type(self, sub_type: int) -> Result.IntResult: select(UserSub).join(Subscription). where(UserSub.sub_id == Subscription.id). where(Subscription.sub_type == sub_type). - where(UserSub.user_id == user_id_result.result) + where(UserSub.user_id == friend_id_result.result) ) for exist_user_sub in session_result.scalars().all(): await session.delete(exist_user_sub) diff --git a/omega_miya/utils/Omega_Base/model/group.py b/omega_miya/utils/Omega_Base/model/group.py index 615c3ac5..b07b7d83 100644 --- a/omega_miya/utils/Omega_Base/model/group.py +++ b/omega_miya/utils/Omega_Base/model/group.py @@ -1,10 +1,6 @@ from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result -from omega_miya.utils.Omega_Base.tables import \ - User, Group, UserGroup, Vocation, Skill, UserSkill, Subscription, GroupSub, AuthGroup, EmailBox, GroupEmailBox -from .user import DBUser, DBSkill -from .subscription import DBSubscription -from .mail import DBEmailBox +from omega_miya.utils.Omega_Base.tables import Group from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound @@ -68,8 +64,7 @@ async def add(self, name: str) -> Result.IntResult: exist_group.updated_at = datetime.now() result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: - new_group = Group(group_id=self.group_id, name=name, notice_permissions=0, - command_permissions=0, permission_level=0, created_at=datetime.now()) + new_group = Group(group_id=self.group_id, name=name, created_at=datetime.now()) session.add(new_group) result = Result.IntResult(error=False, info='Success added', result=0) await session.commit() @@ -90,34 +85,6 @@ async def delete(self) -> Result.IntResult: async with async_session() as session: try: async with session.begin(): - # 清空权限节点 - session_result = await session.execute( - select(AuthGroup).where(AuthGroup.group_id == id_result.result) - ) - for exist_auth_node in session_result.scalars().all(): - await session.delete(exist_auth_node) - - # 清空群成员列表 - session_result = await session.execute( - select(UserGroup).where(UserGroup.group_id == id_result.result) - ) - for exist_user in session_result.scalars().all(): - await session.delete(exist_user) - - # 清空订阅 - session_result = await session.execute( - select(GroupSub).where(GroupSub.group_id == id_result.result) - ) - for exist_group_sub in session_result.scalars().all(): - await session.delete(exist_group_sub) - - # 清空绑定邮箱 - session_result = await session.execute( - select(GroupEmailBox).where(GroupEmailBox.group_id == id_result.result) - ) - for exist_mailbox in session_result.scalars().all(): - await session.delete(exist_mailbox) - # 删除群组表中该群组 session_result = await session.execute( select(Group).where(Group.group_id == self.group_id) @@ -137,621 +104,17 @@ async def delete(self) -> Result.IntResult: result = Result.IntResult(error=True, info=repr(e), result=-1) return result - async def member_list(self) -> Result.ListResult: - id_result = await self.id() - if id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(User.qq, UserGroup.user_group_nickname). - join(UserGroup). - where(UserGroup.group_id == id_result.result) - ) - res = [(x[0], x[1]) for x in session_result.all()] - result = Result.ListResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def member_add(self, user: DBUser, user_group_nickname: str) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - user_id_result = await user.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - # 查询成员-群组表中用户-群关系 - try: - # 用户-群关系已存在, 更新用户群昵称 - session_result = await session.execute( - select(UserGroup). - where(UserGroup.user_id == user_id_result.result). - where(UserGroup.group_id == group_id_result.result) - ) - exist_user = session_result.scalar_one() - exist_user.user_group_nickname = user_group_nickname - exist_user.updated_at = datetime.now() - result = Result.IntResult(error=False, info='Success upgraded', result=0) - except NoResultFound: - # 不存在关系则添加新成员 - new_user = UserGroup(user_id=user_id_result.result, group_id=group_id_result.result, - user_group_nickname=user_group_nickname, created_at=datetime.now()) - session.add(new_user) - result = Result.IntResult(error=False, info='Success added', result=0) - await session.commit() - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def member_del(self, user: DBUser) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - user_id_result = await user.id() - if user_id_result.error: - return Result.IntResult(error=True, info='User not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(UserGroup). - where(UserGroup.user_id == user_id_result.result). - where(UserGroup.group_id == group_id_result.result) - ) - exist_user = session_result.scalar_one() - await session.delete(exist_user) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except NoResultFound: - await session.rollback() - result = Result.IntResult(error=True, info='NoResultFound', result=-1) - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def member_clear(self) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(UserGroup).where(UserGroup.group_id == group_id_result.result) - ) - for exist_user in session_result.scalars().all(): - await session.delete(exist_user) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def permission_reset(self) -> Result.IntResult: - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(Group).where(Group.group_id == self.group_id) - ) - exist_group = session_result.scalar_one() - exist_group.notice_permissions = 0 - exist_group.command_permissions = 0 - exist_group.permission_level = 0 - exist_group.updated_at = datetime.now() - result = Result.IntResult(error=False, info='Success upgraded', result=0) - await session.commit() - except NoResultFound: - await session.rollback() - result = Result.IntResult(error=True, info='NoResultFound', result=-1) - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def permission_set(self, notice: int = 0, command: int = 0, level: int = 0) -> Result.IntResult: - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(Group).where(Group.group_id == self.group_id) - ) - exist_group = session_result.scalar_one() - exist_group.notice_permissions = notice - exist_group.command_permissions = command - exist_group.permission_level = level - exist_group.updated_at = datetime.now() - result = Result.IntResult(error=False, info='Success upgraded', result=0) - await session.commit() - except NoResultFound: - await session.rollback() - result = Result.IntResult(error=True, info='NoResultFound', result=-1) - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def permission_info(self) -> Result.IntTupleResult: - """ - :return: Result: Tuple[Notice_permission, Command_permission, Permission_level] - """ + @classmethod + async def list_all_group(cls) -> Result.IntListResult: async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: session_result = await session.execute( - select(Group.notice_permissions, Group.command_permissions, Group.permission_level). - where(Group.group_id == self.group_id) - ) - notice, command, level = session_result.one() - result = Result.IntTupleResult(error=False, info='Success', result=(notice, command, level)) - except NoResultFound: - result = Result.IntTupleResult(error=True, info='NoResultFound', result=(-1, -1, -1)) - except MultipleResultsFound: - result = Result.IntTupleResult(error=True, info='MultipleResultsFound', result=(-1, -1, -1)) - except Exception as e: - result = Result.IntTupleResult(error=True, info=repr(e), result=(-1, -1, -1)) - return result - - async def permission_notice(self) -> Result.IntResult: - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(Group.notice_permissions).where(Group.group_id == self.group_id) - ) - res = session_result.scalar_one() - result = Result.IntResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def permission_command(self) -> Result.IntResult: - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(Group.command_permissions).where(Group.group_id == self.group_id) - ) - res = session_result.scalar_one() - result = Result.IntResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def permission_level(self) -> Result.IntResult: - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(Group.permission_level).where(Group.group_id == self.group_id) - ) - res = session_result.scalar_one() - result = Result.IntResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def idle_member_list(self) -> Result.ListResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - res = [] - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - # 查询该群组中所有没有假的人 - session_result = await session.execute( - select(User.id, UserGroup.user_group_nickname). - join(Vocation).join(UserGroup). - where(User.id == Vocation.user_id). - where(User.id == UserGroup.user_id). - where(Vocation.status == 0). - where(UserGroup.group_id == group_id_result.result) - ) - user_res = [(x[0], x[1]) for x in session_result.all()] - # 查对应每个人的技能 - for user_id, nickname in user_res: - session_result = await session.execute( - select(Skill.name).join(UserSkill).join(User). - where(Skill.id == UserSkill.skill_id). - where(UserSkill.user_id == User.id). - where(User.id == user_id) - ) - user_skill_res = [x for x in session_result.scalars().all()] - if user_skill_res: - user_skill_text = '/'.join(user_skill_res) - else: - user_skill_text = '暂无技能' - res.append((nickname, user_skill_text)) - result = Result.ListResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def idle_skill_list(self, skill: DBSkill) -> Result.ListResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - skill_id_result = await skill.id() - if skill_id_result.error: - return Result.ListResult(error=True, info='Skill not exist', result=[]) - - res = [] - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - # 查询这这个技能有那些人会 - session_result = await session.execute( - select(User.id, UserGroup.user_group_nickname). - join(UserSkill).join(UserGroup). - where(User.id == UserSkill.user_id). - where(User.id == UserGroup.user_id). - where(UserSkill.skill_id == skill_id_result.result). - where(UserGroup.group_id == group_id_result.result) - ) - user_res = [(x[0], x[1]) for x in session_result.all()] - # 查这个人是不是空闲 - for user_id, nickname in user_res: - session_result = await session.execute( - select(Vocation.status).where(Vocation.user_id == user_id) - ) - # 如果空闲则把这个人昵称放进结果列表里面 - if session_result.scalar_one() == 0: - res.append(nickname) - result = Result.ListResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def vocation_member_list(self) -> Result.ListResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - # 查询所有没有假的人 - session_result = await session.execute( - select(UserGroup.user_group_nickname, Vocation.stop_at). - select_from(UserGroup).join(User). - where(UserGroup.user_id == User.id). - where(User.id == Vocation.user_id). - where(Vocation.status == 1). - where(UserGroup.group_id == group_id_result.result) - ) - res = [(x[0], x[1]) for x in session_result.all()] - result = Result.ListResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def init_member_status(self) -> Result.IntResult: - member_list_res = await self.member_list() - for user_qq, nickname in member_list_res.result: - user = DBUser(user_id=user_qq) - user_status_res = await user.status() - if user_status_res.error: - await user.status_set(status=0) - return Result.IntResult(error=False, info='ignore', result=0) - - async def subscription_list(self) -> Result.ListResult: - """ - :return: Result: List[Tuple[sub_type, sub_id, up_name]] - """ - group_id_result = await self.id() - if group_id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(Subscription.sub_type, Subscription.sub_id, Subscription.up_name). - join(GroupSub). - where(Subscription.id == GroupSub.sub_id). - where(GroupSub.group_id == group_id_result.result) - ) - res = [(x[0], x[1], x[2]) for x in session_result.all()] - result = Result.ListResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def subscription_list_by_type(self, sub_type: int) -> Result.ListResult: - """ - :param sub_type: 订阅类型 - :return: Result: List[Tuple[sub_id, up_name]] - """ - group_id_result = await self.id() - if group_id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(Subscription.sub_id, Subscription.up_name). - join(GroupSub). - where(Subscription.sub_type == sub_type). - where(Subscription.id == GroupSub.sub_id). - where(GroupSub.group_id == group_id_result.result) - ) - res = [(x[0], x[1]) for x in session_result.all()] - result = Result.ListResult(error=False, info='Success', result=res) - except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def subscription_add(self, sub: DBSubscription, group_sub_info: str = None) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - sub_id_result = await sub.id() - if sub_id_result.error: - return Result.IntResult(error=True, info='Subscription not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - try: - session_result = await session.execute( - select(GroupSub). - where(GroupSub.group_id == group_id_result.result). - where(GroupSub.sub_id == sub_id_result.result) - ) - # 订阅关系已存在, 更新信息 - exist_subscription = session_result.scalar_one() - exist_subscription.group_sub_info = group_sub_info - exist_subscription.updated_at = datetime.now() - result = Result.IntResult(error=False, info='Success upgraded', result=0) - except NoResultFound: - subscription = GroupSub(sub_id=sub_id_result.result, group_id=group_id_result.result, - group_sub_info=group_sub_info, created_at=datetime.now()) - session.add(subscription) - result = Result.IntResult(error=False, info='Success added', result=0) - await session.commit() - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def subscription_del(self, sub: DBSubscription) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - sub_id_result = await sub.id() - if sub_id_result.error: - return Result.IntResult(error=True, info='Subscription not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(GroupSub). - where(GroupSub.group_id == group_id_result.result). - where(GroupSub.sub_id == sub_id_result.result) - ) - exist_subscription = session_result.scalar_one() - await session.delete(exist_subscription) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except NoResultFound: - await session.rollback() - result = Result.IntResult(error=True, info='NoResultFound', result=-1) - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def subscription_clear(self) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(GroupSub).where(GroupSub.group_id == group_id_result.result) - ) - for exist_group_sub in session_result.scalars().all(): - await session.delete(exist_group_sub) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def subscription_clear_by_type(self, sub_type: int) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(GroupSub).join(Subscription). - where(GroupSub.sub_id == Subscription.id). - where(Subscription.sub_type == sub_type). - where(GroupSub.group_id == group_id_result.result) - ) - for exist_group_sub in session_result.scalars().all(): - await session.delete(exist_group_sub) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def mailbox_list(self) -> Result.ListResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.ListResult(error=True, info='Group not exist', result=[]) - - async_session = NBdb().get_async_session() - async with async_session() as session: - async with session.begin(): - try: - session_result = await session.execute( - select(EmailBox.address). - join(GroupEmailBox). - where(EmailBox.id == GroupEmailBox.email_box_id). - where(GroupEmailBox.group_id == group_id_result.result) + select(Group.group_id).order_by(Group.group_id) ) res = [x for x in session_result.scalars().all()] - result = Result.ListResult(error=False, info='Success', result=res) + result = Result.IntListResult(error=False, info='Success', result=res) except Exception as e: - result = Result.ListResult(error=True, info=repr(e), result=[]) - return result - - async def mailbox_add(self, mailbox: DBEmailBox, mailbox_info: str = None) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - mailbox_id_result = await mailbox.id() - if mailbox_id_result.error: - return Result.IntResult(error=True, info='Mailbox not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - try: - session_result = await session.execute( - select(GroupEmailBox). - where(GroupEmailBox.group_id == group_id_result.result). - where(GroupEmailBox.email_box_id == mailbox_id_result.result) - ) - # 群邮箱已存在, 更新信息 - exist_mailbox = session_result.scalar_one() - exist_mailbox.box_info = mailbox_info - exist_mailbox.updated_at = datetime.now() - result = Result.IntResult(error=False, info='Success upgraded', result=0) - except NoResultFound: - new_mailbox = GroupEmailBox(email_box_id=mailbox_id_result.result, - group_id=group_id_result.result, - box_info=mailbox_info, created_at=datetime.now()) - session.add(new_mailbox) - result = Result.IntResult(error=False, info='Success added', result=0) - await session.commit() - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def mailbox_del(self, mailbox: DBEmailBox) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - mailbox_id_result = await mailbox.id() - if mailbox_id_result.error: - return Result.IntResult(error=True, info='Mailbox not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(GroupEmailBox). - where(GroupEmailBox.group_id == group_id_result.result). - where(GroupEmailBox.email_box_id == mailbox_id_result.result) - ) - exist_mailbox = session_result.scalar_one() - await session.delete(exist_mailbox) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except NoResultFound: - await session.rollback() - result = Result.IntResult(error=True, info='NoResultFound', result=-1) - except MultipleResultsFound: - await session.rollback() - result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) - return result - - async def mailbox_clear(self) -> Result.IntResult: - group_id_result = await self.id() - if group_id_result.error: - return Result.IntResult(error=True, info='Group not exist', result=-1) - - async_session = NBdb().get_async_session() - async with async_session() as session: - try: - async with session.begin(): - session_result = await session.execute( - select(GroupEmailBox).where(GroupEmailBox.group_id == group_id_result.result) - ) - for exist_mailbox in session_result.scalars().all(): - await session.delete(exist_mailbox) - await session.commit() - result = Result.IntResult(error=False, info='Success', result=0) - except Exception as e: - await session.rollback() - result = Result.IntResult(error=True, info=repr(e), result=-1) + result = Result.IntListResult(error=True, info=repr(e), result=[]) return result diff --git a/omega_miya/utils/Omega_Base/model/history.py b/omega_miya/utils/Omega_Base/model/history.py index 23accb2c..77aa1cfc 100644 --- a/omega_miya/utils/Omega_Base/model/history.py +++ b/omega_miya/utils/Omega_Base/model/history.py @@ -1,6 +1,8 @@ from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result from omega_miya.utils.Omega_Base.tables import History +from sqlalchemy.future import select +from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound from datetime import datetime @@ -11,9 +13,8 @@ def __init__(self, time: int, self_id: int, post_type: str, detail_type: str): self.post_type = post_type self.detail_type = detail_type - async def add(self, sub_type: str = None, event_id: int = None, group_id: int = None, - user_id: int = None, user_name: str = None, - raw_data: str = None, msg_data: str = None) -> Result.IntResult: + async def add(self, sub_type: str = 'Undefined', event_id: int = 0, group_id: int = -1, user_id: int = -1, + user_name: str = None, raw_data: str = None, msg_data: str = None) -> Result.IntResult: async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -29,3 +30,106 @@ async def add(self, sub_type: str = None, event_id: int = None, group_id: int = await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result + + @classmethod + async def search_unique_msg( + cls, + self_id: int, + post_type: str, + detail_type: str, + sub_type: str, + event_id: int, + group_id: int, + user_id: int + ) -> Result.AnyResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(History). + where(History.self_id == self_id). + where(History.post_type == post_type). + where(History.detail_type == detail_type). + where(History.sub_type == sub_type). + where(History.event_id == event_id). + where(History.group_id == group_id). + where(History.user_id == user_id) + ) + exist_history = session_result.scalar_one() + result = Result.AnyResult(error=False, info='Success', result=exist_history) + except NoResultFound: + result = Result.AnyResult(error=True, info='NoResultFound', result=None) + except MultipleResultsFound: + result = Result.AnyResult(error=True, info='MultipleResultsFound', result=None) + except Exception as e: + result = Result.AnyResult(error=True, info=repr(e), result=None) + return result + + @classmethod + async def search_msgs( + cls, + self_id: int, + post_type: str, + detail_type: str, + sub_type: str = 'Undefined', + event_id: int = 0, + group_id: int = -1, + user_id: int = -1 + ) -> Result.ListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(History). + where(History.self_id == self_id). + where(History.post_type == post_type). + where(History.detail_type == detail_type). + where(History.sub_type == sub_type). + where(History.event_id == event_id). + where(History.group_id == group_id). + where(History.user_id == user_id). + order_by(History.time.desc()) + ) + exist_history = [x for x in session_result.scalars()] + result = Result.ListResult(error=False, info='Success', result=exist_history) + except NoResultFound: + result = Result.ListResult(error=True, info='NoResultFound', result=[]) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def search_msgs_data( + cls, + self_id: int, + post_type: str, + detail_type: str, + sub_type: str = 'Undefined', + event_id: int = 0, + group_id: int = -1, + user_id: int = -1 + ) -> Result.TextListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(History.msg_data). + where(History.self_id == self_id). + where(History.post_type == post_type). + where(History.detail_type == detail_type). + where(History.sub_type == sub_type). + where(History.event_id == event_id). + where(History.group_id == group_id). + where(History.user_id == user_id). + order_by(History.time.desc()) + ) + exist_history = [x for x in session_result.scalars()] + result = Result.TextListResult(error=False, info='Success', result=exist_history) + except NoResultFound: + result = Result.TextListResult(error=True, info='NoResultFound', result=[]) + except Exception as e: + result = Result.TextListResult(error=True, info=repr(e), result=[]) + return result diff --git a/omega_miya/utils/Omega_Base/model/pixiv_user_artwork.py b/omega_miya/utils/Omega_Base/model/pixiv_user_artwork.py new file mode 100644 index 00000000..328800fe --- /dev/null +++ b/omega_miya/utils/Omega_Base/model/pixiv_user_artwork.py @@ -0,0 +1,108 @@ +""" +@Author : Ailitonia +@Date : 2021/06/01 21:02 +@FileName : pixiv_user_artwork.py +@Project : nonebot2_miya +@Description : PixivUserArtwork Model +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from omega_miya.utils.Omega_Base.database import NBdb +from omega_miya.utils.Omega_Base.class_result import Result +from omega_miya.utils.Omega_Base.tables import PixivUserArtwork +from datetime import datetime +from sqlalchemy.future import select +from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound + + +class DBPixivUserArtwork(object): + def __init__(self, pid: int, uid: int): + self.pid = pid + self.uid = uid + + async def id(self) -> Result.IntResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivUserArtwork.id). + where(PixivUserArtwork.pid == self.pid). + where(PixivUserArtwork.uid == self.uid) + ) + artwork_table_id = session_result.scalar_one() + result = Result.IntResult(error=False, info='Success', result=artwork_table_id) + except NoResultFound: + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def exist(self) -> bool: + result = await self.id() + return result.success() + + async def add(self, uname: str, title: str) -> Result.IntResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivUserArtwork). + where(PixivUserArtwork.pid == self.pid). + where(PixivUserArtwork.uid == self.uid) + ) + exist_artwork = session_result.scalar_one() + exist_artwork.uname = uname + exist_artwork.title = title + exist_artwork.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Exist artwork updated', result=0) + except NoResultFound: + new_artwork = PixivUserArtwork(pid=self.pid, uid=self.uid, + title=title, uname=uname, created_at=datetime.now()) + session.add(new_artwork) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + @classmethod + async def list_all_artwork(cls) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivUserArtwork.pid).order_by(PixivUserArtwork.pid) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_artwork_by_uid(cls, uid: int) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivUserArtwork.pid). + where(PixivUserArtwork.uid == uid). + order_by(PixivUserArtwork.pid) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result diff --git a/omega_miya/utils/Omega_Base/model/pixivillust.py b/omega_miya/utils/Omega_Base/model/pixivillust.py index 4afd3546..94c65352 100644 --- a/omega_miya/utils/Omega_Base/model/pixivillust.py +++ b/omega_miya/utils/Omega_Base/model/pixivillust.py @@ -1,8 +1,7 @@ from typing import List from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result -from omega_miya.utils.Omega_Base.tables import Pixiv, PixivT2I -from .pixivtag import DBPixivtag +from omega_miya.utils.Omega_Base.tables import Pixiv, PixivPage from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound @@ -36,18 +35,17 @@ async def exist(self) -> bool: result = await self.id() return result.success() - async def add(self, uid: int, title: str, uname: str, nsfw_tag: int, tags: List[str], url: str) -> Result.IntResult: - # 将tag写入pixiv_tag表 - for tag in tags: - _tag = DBPixivtag(tagname=tag) - await _tag.add() - + async def add( + self, + uid: int, title: str, uname: str, nsfw_tag: int, width: int, height: int, tags: List[str], url: str, + *, + force_tag: bool = False + ) -> Result.IntResult: tag_text = ','.join(tags) # 将作品信息写入pixiv_illust表 async_session = NBdb().get_async_session() async with async_session() as session: try: - need_upgrade_pixivt2i = False async with session.begin(): try: session_result = await session.execute( @@ -56,40 +54,66 @@ async def add(self, uid: int, title: str, uname: str, nsfw_tag: int, tags: List[ exist_illust = session_result.scalar_one() exist_illust.title = title exist_illust.uname = uname - if nsfw_tag > exist_illust.nsfw_tag: + if force_tag: + exist_illust.nsfw_tag = nsfw_tag + elif nsfw_tag > exist_illust.nsfw_tag: exist_illust.nsfw_tag = nsfw_tag + exist_illust.width = width + exist_illust.height = height exist_illust.tags = tag_text exist_illust.updated_at = datetime.now() result = Result.IntResult(error=False, info='Exist illust updated', result=0) except NoResultFound: new_illust = Pixiv(pid=self.pid, uid=uid, title=title, uname=uname, url=url, nsfw_tag=nsfw_tag, - tags=tag_text, created_at=datetime.now()) + width=width, height=height, tags=tag_text, created_at=datetime.now()) session.add(new_illust) - need_upgrade_pixivt2i = True result = Result.IntResult(error=False, info='Success added', result=0) await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result - # 写入tag_pixiv关联表 - # 获取本作品在illust表中的id - id_result = await self.id() - if id_result.success() and need_upgrade_pixivt2i: - _illust_id = id_result.result - # 根据作品tag依次写入tag_illust表 - async with session.begin(): - for tag in tags: - _tag = DBPixivtag(tagname=tag) - _tag_id_res = await _tag.id() - if not _tag_id_res.success(): - continue - _tag_id = _tag_id_res.result - try: - new_tag_illust = PixivT2I(illust_id=_illust_id, tag_id=_tag_id, - created_at=datetime.now()) - session.add(new_tag_illust) - except Exception as e: - continue - await session.commit() - result = Result.IntResult(error=False, info='Success added with tags', result=0) + async def upgrade_page( + self, + page: int, + original: str, + regular: str, + small: str, + thumb_mini: str + ) -> Result.IntResult: + pixiv_id_result = await self.id() + if pixiv_id_result.error: + return Result.IntResult(error=True, info='PixivIllust not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivPage). + where(PixivPage.page == page). + where(PixivPage.illust_id == pixiv_id_result.result) + ) + # 已存在, 更新信息 + exist_page = session_result.scalar_one() + exist_page.original = original + exist_page.regular = regular + exist_page.small = small + exist_page.thumb_mini = thumb_mini + exist_page.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + new_page = PixivPage(illust_id=pixiv_id_result.result, page=page, + original=original, regular=regular, small=small, thumb_mini=thumb_mini, + created_at=datetime.now()) + session.add(new_page) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() except MultipleResultsFound: await session.rollback() result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) @@ -98,6 +122,58 @@ async def add(self, uid: int, title: str, uname: str, nsfw_tag: int, tags: List[ result = Result.IntResult(error=True, info=repr(e), result=-1) return result + async def get_page(self, page: int = 0) -> Result.TextTupleResult: + """ + :param page: 页码 + :return: Result: Tuple[original, regular, small, thumb_mini] + """ + pixiv_id_result = await self.id() + if pixiv_id_result.error: + return Result.TextTupleResult(error=True, info='PixivIllust not exist', result=('', '', '', '')) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivPage.original, PixivPage.regular, PixivPage.small, PixivPage.thumb_mini). + where(PixivPage.page == page). + where(PixivPage.illust_id == pixiv_id_result.result) + ) + original, regular, small, thumb_mini = session_result.one() + result = Result.TextTupleResult(error=False, info='Success', + result=(original, regular, small, thumb_mini)) + except NoResultFound: + result = Result.TextTupleResult(error=True, info='NoResultFound', result=('', '', '', '')) + except MultipleResultsFound: + result = Result.TextTupleResult(error=True, info='MultipleResultsFound', result=('', '', '', '')) + except Exception as e: + result = Result.TextTupleResult(error=True, info=repr(e), result=('', '', '', '')) + return result + + async def get_all_page(self) -> Result.ListResult: + """ + :return: Result: List[Tuple[page, original, regular, small, thumb_mini]] + """ + pixiv_id_result = await self.id() + if pixiv_id_result.error: + return Result.ListResult(error=True, info='PixivIllust not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(PixivPage.page, + PixivPage.original, PixivPage.regular, PixivPage.small, PixivPage.thumb_mini). + where(PixivPage.illust_id == pixiv_id_result.result) + ) + res = [(x[0], x[1], x[2], x[3], x[4]) for x in session_result.all()] + result = Result.ListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + @classmethod async def rand_illust(cls, num: int, nsfw_tag: int) -> Result.ListResult: async_session = NBdb().get_async_session() @@ -136,6 +212,39 @@ async def status(cls) -> Result.DictResult: result = Result.DictResult(error=True, info=repr(e), result={}) return result + @classmethod + async def count_keywords(cls, keywords: List[str]) -> Result.DictResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + # 构造查询, 模糊搜索标题, 用户和tag + query = select(func.count(Pixiv.id)) + for keyword in keywords: + query = query.where(or_( + Pixiv.tags.ilike(f'%{keyword}%'), + Pixiv.uname.ilike(f'%{keyword}%'), + Pixiv.title.ilike(f'%{keyword}%') + )) + session_all_result = await session.execute(query) + all_count = session_all_result.scalar() + + session_moe_result = await session.execute(query.where(Pixiv.nsfw_tag == 0)) + moe_count = session_moe_result.scalar() + + session_setu_result = await session.execute(query.where(Pixiv.nsfw_tag == 1)) + setu_count = session_setu_result.scalar() + + session_r18_result = await session.execute(query.where(Pixiv.nsfw_tag == 2)) + r18_count = session_r18_result.scalar() + + res = {'total': int(all_count), 'moe': int(moe_count), + 'setu': int(setu_count), 'r18': int(r18_count)} + result = Result.DictResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.DictResult(error=True, info=repr(e), result={}) + return result + @classmethod async def list_illust( cls, keywords: List[str], num: int, nsfw_tag: int, acc_mode: bool = False) -> Result.ListResult: @@ -173,3 +282,76 @@ async def list_illust( except Exception as e: result = Result.ListResult(error=True, info=repr(e), result=[]) return result + + @classmethod + async def list_all_illust(cls) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute(select(Pixiv.pid)) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_all_illust_by_nsfw_tag(cls, nsfw_tag: int) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute(select(Pixiv.pid).where(Pixiv.nsfw_tag == nsfw_tag)) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def reset_all_nsfw_tag(cls) -> Result.IntResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute(select(Pixiv.pid)) + res = [x for x in session_result.scalars().all()] + for pid in res: + # print(f'reset nsfw tag: {pid}') + session_result = await session.execute( + select(Pixiv).where(Pixiv.pid == pid) + ) + exist_illust = session_result.scalar_one() + exist_illust.nsfw_tag = 0 + result = Result.IntResult(error=False, info='Exist illust updated', result=0) + await session.commit() + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + @classmethod + async def set_nsfw_tag(cls, tags: dict) -> Result.IntResult: + """ + :param tags: Dict[pid: int, nsfw_tag: int] + :return: + """ + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + for pid, nsfw_tag in tags.items(): + # print(f'set nsfw tag: {pid}, {nsfw_tag}') + nsfw_tag = str(nsfw_tag) + session_result = await session.execute( + select(Pixiv).where(Pixiv.pid == pid) + ) + exist_illust = session_result.scalar_one() + exist_illust.nsfw_tag = nsfw_tag + result = Result.IntResult(error=False, info='Exist illust updated', result=0) + await session.commit() + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result diff --git a/omega_miya/utils/Omega_Base/model/pixivision.py b/omega_miya/utils/Omega_Base/model/pixivision.py index 72e7dcad..489b5324 100644 --- a/omega_miya/utils/Omega_Base/model/pixivision.py +++ b/omega_miya/utils/Omega_Base/model/pixivision.py @@ -63,3 +63,18 @@ async def add(self, title: str, description: str, tags: str, illust_id: str, url await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result + + @classmethod + async def list_article_id(cls) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Pixivision.aid).order_by(Pixivision.aid) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result diff --git a/omega_miya/utils/Omega_Base/model/skill.py b/omega_miya/utils/Omega_Base/model/skill.py index 80dd45ad..9dde7c43 100644 --- a/omega_miya/utils/Omega_Base/model/skill.py +++ b/omega_miya/utils/Omega_Base/model/skill.py @@ -129,3 +129,18 @@ async def able_member_clear(self) -> Result.IntResult: await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result + + @classmethod + async def list_available_skill(cls) -> Result.TextListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Skill.name).order_by(Skill.name) + ) + res = [x for x in session_result.scalars().all()] + result = Result.TextListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.TextListResult(error=True, info=repr(e), result=[]) + return result diff --git a/omega_miya/utils/Omega_Base/model/subscription.py b/omega_miya/utils/Omega_Base/model/subscription.py index 9f036dbb..1ec38d89 100644 --- a/omega_miya/utils/Omega_Base/model/subscription.py +++ b/omega_miya/utils/Omega_Base/model/subscription.py @@ -1,6 +1,7 @@ from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result -from omega_miya.utils.Omega_Base.tables import Subscription, Group, GroupSub, User, UserSub +from omega_miya.utils.Omega_Base.tables import Subscription, Group, BotGroup, GroupSub, User, Friends, UserSub +from .bot_self import DBBot from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound @@ -97,18 +98,73 @@ async def delete(self) -> Result.IntResult: result = Result.IntResult(error=True, info=repr(e), result=-1) return result - async def sub_group_list(self) -> Result.ListResult: + async def get_name(self) -> Result.TextResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Subscription.up_name). + where(Subscription.sub_type == self.sub_type). + where(Subscription.sub_id == self.sub_id) + ) + subscription_up_name = session_result.scalar_one() + result = Result.TextResult(error=False, info='Success', result=subscription_up_name) + except NoResultFound: + result = Result.TextResult(error=True, info='NoResultFound', result='') + except MultipleResultsFound: + result = Result.TextResult(error=True, info='MultipleResultsFound', result='') + except Exception as e: + result = Result.TextResult(error=True, info=repr(e), result='') + return result + + async def sub_group_list(self, self_bot: DBBot) -> Result.ListResult: + id_result = await self.id() + if id_result.error: + return Result.ListResult(error=True, info='Subscription not exist', result=[]) + + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Group.group_id). + join(BotGroup).join(GroupSub). + where(Group.id == BotGroup.group_id). + where(BotGroup.id == GroupSub.group_id). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(GroupSub.sub_id == id_result.result) + ) + res = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + async def sub_group_list_by_notice_permission(self, self_bot: DBBot, notice_permission: int) -> Result.ListResult: id_result = await self.id() if id_result.error: return Result.ListResult(error=True, info='Subscription not exist', result=[]) + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: session_result = await session.execute( - select(Group.group_id).join(GroupSub). - where(Group.id == GroupSub.group_id). + select(Group.group_id). + join(BotGroup).join(GroupSub). + where(Group.id == BotGroup.group_id). + where(BotGroup.id == GroupSub.group_id). + where(BotGroup.bot_self_id == self_bot_id_result.result). + where(BotGroup.notice_permissions == notice_permission). where(GroupSub.sub_id == id_result.result) ) res = [x for x in session_result.scalars().all()] @@ -138,18 +194,53 @@ async def sub_group_clear(self) -> Result.IntResult: result = Result.IntResult(error=True, info=repr(e), result=-1) return result - async def sub_user_list(self) -> Result.ListResult: + async def sub_user_list(self, self_bot: DBBot) -> Result.ListResult: id_result = await self.id() if id_result.error: return Result.ListResult(error=True, info='Subscription not exist', result=[]) + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + async_session = NBdb().get_async_session() async with async_session() as session: async with session.begin(): try: session_result = await session.execute( - select(User.qq).join(UserSub). - where(User.id == UserSub.user_id). + select(User.qq). + join(Friends).join(UserSub). + where(User.id == Friends.user_id). + where(Friends.id == UserSub.user_id). + where(Friends.bot_self_id == self_bot_id_result.result). + where(UserSub.sub_id == id_result.result) + ) + res = [x for x in session_result.scalars().all()] + result = Result.ListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.ListResult(error=True, info=repr(e), result=[]) + return result + + async def sub_user_list_by_private_permission(self, self_bot: DBBot, private_permission: int) -> Result.ListResult: + id_result = await self.id() + if id_result.error: + return Result.ListResult(error=True, info='Subscription not exist', result=[]) + + self_bot_id_result = await self_bot.id() + if self_bot_id_result.error: + return Result.ListResult(error=True, info='Bot not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(User.qq). + join(Friends).join(UserSub). + where(User.id == Friends.user_id). + where(Friends.id == UserSub.user_id). + where(Friends.bot_self_id == self_bot_id_result.result). + where(Friends.private_permissions == private_permission). where(UserSub.sub_id == id_result.result) ) res = [x for x in session_result.scalars().all()] @@ -178,3 +269,35 @@ async def sub_user_clear(self) -> Result.IntResult: await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result + + @classmethod + async def list_all_sub(cls) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Subscription.sub_id).order_by(Subscription.sub_id) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result + + @classmethod + async def list_sub_by_type(cls, sub_type: int) -> Result.IntListResult: + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(Subscription.sub_id). + where(Subscription.sub_type == sub_type). + order_by(Subscription.sub_id) + ) + res = [x for x in session_result.scalars().all()] + result = Result.IntListResult(error=False, info='Success', result=res) + except Exception as e: + result = Result.IntListResult(error=True, info=repr(e), result=[]) + return result diff --git a/omega_miya/utils/Omega_Base/model/user.py b/omega_miya/utils/Omega_Base/model/user.py index 54e27ce7..8c5d9b72 100644 --- a/omega_miya/utils/Omega_Base/model/user.py +++ b/omega_miya/utils/Omega_Base/model/user.py @@ -1,8 +1,10 @@ +from typing import List, Optional +from datetime import date, datetime +from dataclasses import dataclass from omega_miya.utils.Omega_Base.database import NBdb from omega_miya.utils.Omega_Base.class_result import Result -from omega_miya.utils.Omega_Base.tables import User, UserGroup, Skill, UserSkill, UserSub, Vocation, AuthUser +from omega_miya.utils.Omega_Base.tables import User, UserFavorability, UserSignIn, Skill, UserSkill, Vacation from .skill import DBSkill -from datetime import datetime from sqlalchemy.future import select from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound @@ -11,6 +13,13 @@ class DBUser(object): def __init__(self, user_id: int): self.qq = user_id + @dataclass + class DateListResult(Result.AnyResult): + result: List[date] + + def __repr__(self): + return f'' + async def id(self) -> Result.IntResult: async_session = NBdb().get_async_session() async with async_session() as session: @@ -51,7 +60,7 @@ async def nickname(self) -> Result.TextResult: result = Result.TextResult(error=True, info=repr(e), result='') return result - async def add(self, nickname: str, is_friend: int = 0, aliasname: str = None) -> Result.IntResult: + async def add(self, nickname: str, aliasname: str = None) -> Result.IntResult: async_session = NBdb().get_async_session() async with async_session() as session: try: @@ -66,14 +75,12 @@ async def add(self, nickname: str, is_friend: int = 0, aliasname: str = None) -> result = Result.IntResult(error=False, info='Nickname not change', result=0) else: exist_user.nickname = nickname - exist_user.is_friend = is_friend exist_user.aliasname = aliasname exist_user.updated_at = datetime.now() result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: # 不存在则成员表中添加新成员 - new_user = User(qq=self.qq, nickname=nickname, is_friend=is_friend, - aliasname=aliasname, created_at=datetime.now()) + new_user = User(qq=self.qq, nickname=nickname, aliasname=aliasname, created_at=datetime.now()) session.add(new_user) result = Result.IntResult(error=False, info='Success added', result=0) await session.commit() @@ -94,41 +101,6 @@ async def delete(self) -> Result.IntResult: async with async_session() as session: try: async with session.begin(): - # 清空该用户权限节点 - session_result = await session.execute( - select(AuthUser).where(AuthUser.user_id == id_result.result) - ) - for exist_auth_node in session_result.scalars().all(): - await session.delete(exist_auth_node) - - # 清空技能 - session_result = await session.execute( - select(UserSkill).where(UserSkill.user_id == id_result.result) - ) - for exist_skill in session_result.scalars().all(): - await session.delete(exist_skill) - - # 删除状态和假期 - session_result = await session.execute( - select(Vocation).where(Vocation.user_id == id_result.result) - ) - exist_status = session_result.scalar_one() - await session.delete(exist_status) - - # 清空订阅 - session_result = await session.execute( - select(UserSub).where(UserSub.user_id == id_result.result) - ) - for exist_user_sub in session_result.scalars().all(): - await session.delete(exist_user_sub) - - # 清空群成员表中该用户 - session_result = await session.execute( - select(UserGroup).where(UserGroup.user_id == id_result.result) - ) - for exist_user in session_result.scalars().all(): - await session.delete(exist_user) - # 删除用户表中用户 session_result = await session.execute( select(User).where(User.qq == self.qq) @@ -272,7 +244,7 @@ async def status(self) -> Result.IntResult: async with session.begin(): try: session_result = await session.execute( - select(Vocation.status).where(Vocation.user_id == user_id_result.result) + select(Vacation.status).where(Vacation.user_id == user_id_result.result) ) res = session_result.scalar_one() result = Result.IntResult(error=False, info='Success', result=res) @@ -280,7 +252,7 @@ async def status(self) -> Result.IntResult: result = Result.IntResult(error=True, info=repr(e), result=-1) return result - async def vocation_status(self) -> Result.ListResult: + async def vacation_status(self) -> Result.ListResult: user_id_result = await self.id() if user_id_result.error: return Result.ListResult(error=True, info='User not exist', result=[-1, None]) @@ -290,8 +262,8 @@ async def vocation_status(self) -> Result.ListResult: async with session.begin(): try: session_result = await session.execute( - select(Vocation.status, Vocation.stop_at). - where(Vocation.user_id == user_id_result.result) + select(Vacation.status, Vacation.stop_at). + where(Vacation.user_id == user_id_result.result) ) res = session_result.one() result = Result.ListResult(error=False, info='Success', result=[res[0], res[1]]) @@ -310,7 +282,7 @@ async def status_set(self, status: int) -> Result.IntResult: async with session.begin(): try: session_result = await session.execute( - select(Vocation).where(Vocation.user_id == user_id_result.result) + select(Vacation).where(Vacation.user_id == user_id_result.result) ) exist_status = session_result.scalar_one() exist_status.status = status @@ -319,7 +291,7 @@ async def status_set(self, status: int) -> Result.IntResult: exist_status.updated_at = datetime.now() result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: - new_status = Vocation(user_id=user_id_result.result, status=status, created_at=datetime.now()) + new_status = Vacation(user_id=user_id_result.result, status=status, created_at=datetime.now()) session.add(new_status) result = Result.IntResult(error=False, info='Success set', result=0) await session.commit() @@ -331,7 +303,7 @@ async def status_set(self, status: int) -> Result.IntResult: result = Result.IntResult(error=True, info=repr(e), result=-1) return result - async def vocation_set(self, stop_time: datetime, reason: str = None) -> Result.IntResult: + async def vacation_set(self, stop_time: datetime, reason: str = None) -> Result.IntResult: user_id_result = await self.id() if user_id_result.error: return Result.IntResult(error=True, info='User not exist', result=-1) @@ -342,7 +314,7 @@ async def vocation_set(self, stop_time: datetime, reason: str = None) -> Result. async with session.begin(): try: session_result = await session.execute( - select(Vocation).where(Vocation.user_id == user_id_result.result) + select(Vacation).where(Vacation.user_id == user_id_result.result) ) exist_status = session_result.scalar_one() exist_status.status = 1 @@ -351,7 +323,7 @@ async def vocation_set(self, stop_time: datetime, reason: str = None) -> Result. exist_status.updated_at = datetime.now() result = Result.IntResult(error=False, info='Success upgraded', result=0) except NoResultFound: - new_status = Vocation(user_id=user_id_result.result, status=1, + new_status = Vacation(user_id=user_id_result.result, status=1, stop_at=stop_time, reason=reason, created_at=datetime.now()) session.add(new_status) result = Result.IntResult(error=False, info='Success set', result=0) @@ -374,7 +346,7 @@ async def status_del(self) -> Result.IntResult: try: async with session.begin(): session_result = await session.execute( - select(Vocation).where(Vocation.user_id == user_id_result.result) + select(Vacation).where(Vacation.user_id == user_id_result.result) ) exist_status = session_result.scalar_one() await session.delete(exist_status) @@ -384,3 +356,236 @@ async def status_del(self) -> Result.IntResult: await session.rollback() result = Result.IntResult(error=True, info=repr(e), result=-1) return result + + async def sign_in(self, *, sign_in_info: Optional[str] = 'Normal sign in') -> Result.IntResult: + """ + 签到 + :param sign_in_info: 签到信息 + :return: IntResult + 1: 已签到 + 0: 签到成功 + -1: 错误 + """ + user_id_result = await self.id() + if user_id_result.error: + return Result.IntResult(error=True, info='User not exist', result=-1) + + datetime_now = datetime.now() + date_now = datetime_now.date() + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(UserSignIn). + where(UserSignIn.user_id == user_id_result.result). + where(UserSignIn.sign_in_date == date_now) + ) + # 已有签到记录 + exist_sign_in = session_result.scalar_one() + exist_sign_in.sign_in_info = 'Duplicate sign in' + exist_sign_in.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=1) + except NoResultFound: + sign_in = UserSignIn(user_id=user_id_result.result, sign_in_date=date_now, + sign_in_info=sign_in_info, created_at=datetime.now()) + session.add(sign_in) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def sign_in_statistics(self) -> DateListResult: + """ + 查询所有签到记录 + :return: Result: List[sign_in_date] + """ + user_id_result = await self.id() + if user_id_result.error: + return self.DateListResult(error=True, info='User not exist', result=[]) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(UserSignIn.sign_in_date). + where(UserSignIn.user_id == user_id_result.result) + ) + res = [x for x in session_result.scalars().all()] + result = self.DateListResult(error=False, info='Success', result=res) + except Exception as e: + result = self.DateListResult(error=True, info=repr(e), result=[]) + return result + + async def sign_in_continuous_days(self) -> Result.IntResult: + """ + 查询到目前为止最长连续签到日数 + """ + sign_in_statistics_result = await self.sign_in_statistics() + if sign_in_statistics_result.error: + return Result.IntResult(error=True, info=sign_in_statistics_result.info, result=-1) + + # 还没有签到过 + if not sign_in_statistics_result.result: + return Result.IntResult(error=False, info='Success with sign in not found', result=0) + + datetime_now = datetime.now() + date_now = datetime_now.date() + date_now_toordinal = date_now.toordinal() + + # 先将签到记录中的日期转化为整数便于比较 + all_sign_in_list = list(set([x.toordinal() for x in sign_in_statistics_result.result])) + # 去重后由大到小排序 + all_sign_in_list.sort(reverse=True) + + # 如果今日日期不等于已签到日期最大值, 说明今日没有签到, 则连签日数为0 + if date_now_toordinal != all_sign_in_list[0]: + return Result.IntResult(error=False, info='Success with not sign in today', result=0) + + # 从大到小检查(即日期从后向前检查), 如果当日序号大小大于与今日日期之差, 说明在这里断签了 + for index, value in enumerate(all_sign_in_list): + if index != date_now_toordinal - value: + return Result.IntResult(error=False, info='Success with found interrupt', result=index) + else: + # 如果全部遍历完了那就说明全部没有断签 + return Result.IntResult(error=False, info='Success with all continuous', result=len(all_sign_in_list)) + + async def favorability_status(self) -> Result.TupleResult: + """ + 查询好感度记录 + :return: Result: + Tuple[status: str, mood: float, favorability: float, energy: float, currency: float, response_threshold: float] + """ + user_id_result = await self.id() + if user_id_result.error: + return Result.TupleResult(error=True, info='User not exist', result=()) + + async_session = NBdb().get_async_session() + async with async_session() as session: + async with session.begin(): + try: + session_result = await session.execute( + select(UserFavorability.status, + UserFavorability.mood, + UserFavorability.favorability, + UserFavorability.energy, + UserFavorability.currency, + UserFavorability.response_threshold). + where(UserFavorability.user_id == user_id_result.result) + ) + res = session_result.one() + result = Result.TupleResult(error=False, info='Success', result=res) + except NoResultFound: + result = Result.TupleResult(error=True, info='NoResultFound', result=()) + except MultipleResultsFound: + result = Result.TupleResult(error=True, info='MultipleResultsFound', result=()) + except Exception as e: + result = Result.TupleResult(error=True, info=repr(e), result=()) + return result + + async def favorability_reset( + self, + *, + status: str = 'normal', + mood: float = 0, + favorability: float = 0, + energy: float = 0, + currency: float = 0, + response_threshold: float = 0 + ) -> Result.IntResult: + user_id_result = await self.id() + if user_id_result.error: + return Result.IntResult(error=True, info='User not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + try: + session_result = await session.execute( + select(UserFavorability). + where(UserFavorability.user_id == user_id_result.result) + ) + # 已有好感度记录条目 + exist_favorability = session_result.scalar_one() + exist_favorability.status = status + exist_favorability.mood = mood + exist_favorability.favorability = favorability + exist_favorability.energy = energy + exist_favorability.currency = currency + exist_favorability.response_threshold = response_threshold + exist_favorability.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + except NoResultFound: + favorability = UserFavorability( + user_id=user_id_result.result, status=status, mood=mood, favorability=favorability, + energy=energy, currency=currency, response_threshold=response_threshold, + created_at=datetime.now()) + session.add(favorability) + result = Result.IntResult(error=False, info='Success added', result=0) + await session.commit() + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result + + async def favorability_add( + self, + *, + status: Optional[str] = None, + mood: Optional[float] = None, + favorability: Optional[float] = None, + energy: Optional[float] = None, + currency: Optional[float] = None, + response_threshold: Optional[float] = None + ) -> Result.IntResult: + user_id_result = await self.id() + if user_id_result.error: + return Result.IntResult(error=True, info='User not exist', result=-1) + + async_session = NBdb().get_async_session() + async with async_session() as session: + try: + async with session.begin(): + session_result = await session.execute( + select(UserFavorability). + where(UserFavorability.user_id == user_id_result.result) + ) + # 已有好感度记录条目 + exist_favorability = session_result.scalar_one() + if status: + exist_favorability.status = status + if mood: + exist_favorability.mood += mood + if favorability: + exist_favorability.favorability += favorability + if energy: + exist_favorability.energy += energy + if currency: + exist_favorability.currency += currency + if response_threshold: + exist_favorability.response_threshold += response_threshold + exist_favorability.updated_at = datetime.now() + result = Result.IntResult(error=False, info='Success upgraded', result=0) + await session.commit() + except NoResultFound: + await session.rollback() + result = Result.IntResult(error=True, info='NoResultFound', result=-1) + except MultipleResultsFound: + await session.rollback() + result = Result.IntResult(error=True, info='MultipleResultsFound', result=-1) + except Exception as e: + await session.rollback() + result = Result.IntResult(error=True, info=repr(e), result=-1) + return result diff --git a/omega_miya/utils/Omega_Base/tables.py b/omega_miya/utils/Omega_Base/tables.py index ebb5ddb4..87dcf34a 100644 --- a/omega_miya/utils/Omega_Base/tables.py +++ b/omega_miya/utils/Omega_Base/tables.py @@ -1,6 +1,8 @@ import nonebot +from datetime import datetime, date as date_ +from typing import Optional from sqlalchemy import Sequence, ForeignKey -from sqlalchemy import Column, Integer, BigInteger, String, DateTime +from sqlalchemy import Column, Integer, BigInteger, Float, String, Date, DateTime from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base @@ -25,7 +27,13 @@ class OmegaStatus(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - def __init__(self, name, status, info, created_at=None, updated_at=None): + def __init__(self, + name: str, + status: int, + *, + info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.name = name self.status = status self.info = info @@ -37,7 +45,44 @@ def __repr__(self): f"created_at='{self.created_at}', updated_at='{self.updated_at}')>" -# 成员表 +# Bot表 对应不同机器人协议端 +class BotSelf(Base): + __tablename__ = f'{TABLE_PREFIX}bots' + __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} + + # 表结构 + id = Column(Integer, Sequence('bots_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + self_qq = Column(BigInteger, nullable=False, index=True, unique=True, comment='Bot的QQ号') + status = Column(Integer, nullable=False, comment='在线状态') + info = Column(String(1024), nullable=True, comment='信息') + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + + # 设置级联和关系加载 + bots_bot_friends = relationship('Friends', back_populates='bot_friends_back_bots', + cascade='all, delete-orphan', passive_deletes=True) + bots_bot_groups = relationship('BotGroup', back_populates='bot_groups_back_bots', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + self_qq: int, + status: int, + *, + info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): + self.self_qq = self_qq + self.status = status + self.info = info + self.created_at = created_at + self.updated_at = updated_at + + def __repr__(self): + return f"" + + +# 用户表 class User(Base): __tablename__ = f'{TABLE_PREFIX}users' __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} @@ -46,56 +91,76 @@ class User(Base): id = Column(Integer, Sequence('users_id_seq'), primary_key=True, nullable=False, index=True, unique=True) qq = Column(BigInteger, nullable=False, index=True, unique=True, comment='QQ号') nickname = Column(String(64), nullable=False, comment='昵称') - is_friend = Column(Integer, nullable=False, comment='是否为好友(已弃用)') aliasname = Column(String(64), nullable=True, comment='自定义名称') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - # 声明外键联系 - has_friends = relationship('Friends', back_populates='user_friend', uselist=False, - cascade="all, delete", passive_deletes=True) - has_skills = relationship('UserSkill', back_populates='user_skill', - cascade="all, delete", passive_deletes=True) - in_which_groups = relationship('UserGroup', back_populates='user_groups', - cascade="all, delete", passive_deletes=True) - vocation = relationship('Vocation', back_populates='vocation_for_user', uselist=False, - cascade="all, delete", passive_deletes=True) - user_auth = relationship('AuthUser', back_populates='auth_for_user', uselist=False, - cascade="all, delete", passive_deletes=True) - users_sub_what = relationship('UserSub', back_populates='users_sub', - cascade="all, delete", passive_deletes=True) - - def __init__(self, qq, nickname, is_friend=0, aliasname=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + users_bot_friends = relationship('Friends', back_populates='bot_friends_back_users', + cascade='all, delete-orphan', passive_deletes=True) + users_users_groups = relationship('UserGroup', back_populates='users_groups_back_users', + cascade='all, delete-orphan', passive_deletes=True) + users_users_skills = relationship('UserSkill', back_populates='users_skills_back_users', + cascade='all, delete-orphan', passive_deletes=True) + users_vacations = relationship('Vacation', back_populates='vacations_back_users', + cascade='all, delete-orphan', passive_deletes=True) + user_user_favorability = relationship('UserFavorability', back_populates='user_favorability_back_user', + cascade='all, delete-orphan', passive_deletes=True) + user_friend_sign_in = relationship('UserSignIn', back_populates='user_sign_in_back_user', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + qq: int, + nickname: str, + *, + aliasname: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.qq = qq self.nickname = nickname - self.is_friend = is_friend self.aliasname = aliasname self.created_at = created_at self.updated_at = updated_at def __repr__(self): return f"" + f"created_at='{self.created_at}', updated_at='{self.updated_at}')>" # 好友表 class Friends(Base): - __tablename__ = f'{TABLE_PREFIX}friends' + __tablename__ = f'{TABLE_PREFIX}bot_friends' __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} # 表结构 - id = Column(Integer, Sequence('friends_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id'), nullable=False) + id = Column(Integer, Sequence('bot_friends_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id', ondelete='CASCADE'), nullable=False) + bot_self_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bots.id', ondelete='CASCADE'), nullable=False) nickname = Column(String(64), nullable=False, comment='昵称') remark = Column(String(64), nullable=True, comment='备注') private_permissions = Column(Integer, nullable=False, comment='是否启用私聊权限') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - user_friend = relationship('User', back_populates='has_friends') - - def __init__(self, user_id, nickname, remark=None, private_permissions=0, created_at=None, updated_at=None): + # 设置级联和关系加载 + bot_friends_back_bots = relationship(BotSelf, back_populates='bots_bot_friends', lazy='joined', innerjoin=True) + bot_friends_back_users = relationship(User, back_populates='users_bot_friends', lazy='joined', innerjoin=True) + bot_friends_auth_user = relationship('AuthUser', back_populates='auth_user_back_bot_friends', + cascade='all, delete-orphan', passive_deletes=True) + bot_friends_users_subs = relationship('UserSub', back_populates='users_subs_back_bot_friends', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + user_id: int, + bot_self_id: int, + nickname: str, + private_permissions: int = 0, + *, + remark: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.user_id = user_id + self.bot_self_id = bot_self_id self.nickname = nickname self.remark = remark self.private_permissions = private_permissions @@ -103,11 +168,95 @@ def __init__(self, user_id, nickname, remark=None, private_permissions=0, create self.updated_at = updated_at def __repr__(self): - return f"" +# 好感度及状态表, 养成系统基础表单 +class UserFavorability(Base): + __tablename__ = f'{TABLE_PREFIX}user_favorability' + __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} + + id = Column(Integer, Sequence('user_fav_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id', ondelete='CASCADE'), nullable=False) + status = Column(String(64), nullable=False, comment='当前状态') + mood = Column(Float, nullable=False, comment='情绪值, 大于0: 好心情, 小于零: 坏心情') + favorability = Column(Float, nullable=False, comment='好感度, 大于0: 友善, 小于0: 敌对') + energy = Column(Float, nullable=False, comment='能量值') + currency = Column(Float, nullable=False, comment='持有货币') + response_threshold = Column(Float, nullable=False, comment='响应阈值, 控制对交互做出响应的概率或频率, 根据具体插件使用数值') + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + + # 设置级联和关系加载 + user_favorability_back_user = relationship(User, back_populates='user_user_favorability', + lazy='joined', innerjoin=True) + + def __init__(self, + user_id: int, + status: str, + mood: float, + favorability: float, + energy: float, + currency: float, + response_threshold: float, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): + self.user_id = user_id + self.status = status + self.mood = mood + self.favorability = favorability + self.energy = energy + self.currency = currency + self.response_threshold = response_threshold + self.created_at = created_at + self.updated_at = updated_at + + def __repr__(self): + return f"" + + +# 签到表, 养成系统基础表单 +class UserSignIn(Base): + __tablename__ = f'{TABLE_PREFIX}user_sign_in' + __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} + + id = Column(Integer, Sequence('user_sign_in_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id', ondelete='CASCADE'), nullable=False) + sign_in_date = Column(Date, nullable=False, comment='签到日期') + sign_in_info = Column(String(64), nullable=True, comment='签到信息') + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + + # 设置级联和关系加载 + user_sign_in_back_user = relationship(User, back_populates='user_friend_sign_in', + lazy='joined', innerjoin=True) + + def __init__(self, + user_id: int, + sign_in_date: date_, + *, + sign_in_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): + self.user_id = user_id + self.sign_in_date = sign_in_date + self.sign_in_info = sign_in_info + self.created_at = created_at + self.updated_at = updated_at + + def __repr__(self): + return f"" + + # 技能表 class Skill(Base): __tablename__ = f'{TABLE_PREFIX}skills' @@ -119,10 +268,16 @@ class Skill(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - avaiable_skills = relationship('UserSkill', back_populates='skill_used', - cascade="all, delete", passive_deletes=True) + # 设置级联和关系加载 + skills_users_skills = relationship('UserSkill', back_populates='users_skills_back_skills', + cascade='all, delete-orphan', passive_deletes=True) - def __init__(self, name, description=None, created_at=None, updated_at=None): + def __init__(self, + name: str, + *, + description: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.name = name self.description = description self.created_at = created_at @@ -139,16 +294,23 @@ class UserSkill(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('users_skills_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id'), nullable=False) - skill_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}skills.id'), nullable=False) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id', ondelete='CASCADE'), nullable=False) + skill_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}skills.id', ondelete='CASCADE'), nullable=False) skill_level = Column(Integer, nullable=False, comment='技能等级') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - user_skill = relationship('User', back_populates='has_skills') - skill_used = relationship('Skill', back_populates='avaiable_skills') - - def __init__(self, user_id, skill_id, skill_level, created_at=None, updated_at=None): + # 设置级联和关系加载 + users_skills_back_users = relationship(User, back_populates='users_users_skills', lazy='joined', innerjoin=True) + users_skills_back_skills = relationship(Skill, back_populates='skills_users_skills', lazy='joined', innerjoin=True) + + def __init__(self, + user_id: int, + skill_id: int, + skill_level: int, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.user_id = user_id self.skill_id = skill_id self.skill_level = skill_level @@ -168,25 +330,71 @@ class Group(Base): id = Column(Integer, Sequence('groups_id_seq'), primary_key=True, nullable=False, index=True, unique=True) name = Column(String(64), nullable=False, comment='qq群名称') group_id = Column(BigInteger, nullable=False, index=True, unique=True, comment='qq群号') + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + + # 设置级联和关系加载 + groups_bot_groups = relationship('BotGroup', back_populates='bot_groups_back_groups', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + name: str, + group_id: int, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): + self.name = name + self.group_id = group_id + self.created_at = created_at + self.updated_at = updated_at + + def __repr__(self): + return f"" + + +# Bot对应qq群表 +class BotGroup(Base): + __tablename__ = f'{TABLE_PREFIX}bot_groups' + __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} + + id = Column(Integer, Sequence('bot_groups_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}groups.id', ondelete='CASCADE'), nullable=False) + bot_self_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bots.id', ondelete='CASCADE'), nullable=False) + group_memo = Column(String(64), nullable=True, comment='群备注') notice_permissions = Column(Integer, nullable=False, comment='通知权限') command_permissions = Column(Integer, nullable=False, comment='命令权限') permission_level = Column(Integer, nullable=False, comment='权限等级, 越大越高') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - avaiable_groups = relationship('UserGroup', back_populates='groups_have_users', - cascade="all, delete", passive_deletes=True) - sub_what = relationship('GroupSub', back_populates='groups_sub', - cascade="all, delete", passive_deletes=True) - group_auth = relationship('AuthGroup', back_populates='auth_for_group', uselist=False, - cascade="all, delete", passive_deletes=True) - group_box = relationship('GroupEmailBox', back_populates='box_for_group', - cascade="all, delete", passive_deletes=True) - - def __init__(self, name, group_id, notice_permissions, command_permissions, - permission_level, created_at=None, updated_at=None): - self.name = name + # 设置级联和关系加载 + bot_groups_back_bots = relationship(BotSelf, back_populates='bots_bot_groups', lazy='joined', innerjoin=True) + bot_groups_back_groups = relationship(Group, back_populates='groups_bot_groups', lazy='joined', innerjoin=True) + bot_groups_users_groups = relationship('UserGroup', back_populates='users_groups_back_bot_groups', + cascade='all, delete-orphan', passive_deletes=True) + bot_groups_groups_subs = relationship('GroupSub', back_populates='groups_subs_back_bot_groups', + cascade='all, delete-orphan', passive_deletes=True) + bot_groups_auth_group = relationship('AuthGroup', back_populates='auth_group_back_bot_groups', + cascade='all, delete-orphan', passive_deletes=True) + bot_groups_group_email_box = relationship('GroupEmailBox', back_populates='group_email_box_back_bot_groups', + cascade='all, delete-orphan', passive_deletes=True) + bot_groups_groups_settings = relationship('GroupSetting', back_populates='groups_settings_back_bot_groups', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + group_id: int, + bot_self_id: int, + notice_permissions: int, + command_permissions: int, + permission_level: int, + *, + group_memo: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.group_id = group_id + self.bot_self_id = bot_self_id + self.group_memo = group_memo self.notice_permissions = notice_permissions self.command_permissions = command_permissions self.permission_level = permission_level @@ -194,9 +402,9 @@ def __init__(self, name, group_id, notice_permissions, command_permissions, self.updated_at = updated_at def __repr__(self): - return f"" @@ -206,16 +414,24 @@ class UserGroup(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('users_groups_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id'), nullable=False) - group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}groups.id'), nullable=False) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id', ondelete='CASCADE'), nullable=False) + group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_groups.id', ondelete='CASCADE'), nullable=False) user_group_nickname = Column(String(64), nullable=True, comment='用户群昵称') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - user_groups = relationship('User', back_populates='in_which_groups') - groups_have_users = relationship('Group', back_populates='avaiable_groups') - - def __init__(self, user_id, group_id, user_group_nickname=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + users_groups_back_users = relationship(User, back_populates='users_users_groups', lazy='joined', innerjoin=True) + users_groups_back_bot_groups = relationship(BotGroup, back_populates='bot_groups_users_groups', + lazy='joined', innerjoin=True) + + def __init__(self, + user_id: int, + group_id: int, + *, + user_group_nickname: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.user_id = user_id self.group_id = group_id self.user_group_nickname = user_group_nickname @@ -234,7 +450,7 @@ class AuthUser(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('auth_user_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id'), nullable=False) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_friends.id', ondelete='CASCADE'), nullable=False) auth_node = Column(String(128), nullable=False, index=True, comment='授权节点, 由插件检查') allow_tag = Column(Integer, nullable=False, comment='授权标签') deny_tag = Column(Integer, nullable=False, comment='拒绝标签') @@ -242,9 +458,19 @@ class AuthUser(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - auth_for_user = relationship('User', back_populates='user_auth') - - def __init__(self, user_id, auth_node, allow_tag=0, deny_tag=0, auth_info=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + auth_user_back_bot_friends = relationship(Friends, back_populates='bot_friends_auth_user', + lazy='joined', innerjoin=True) + + def __init__(self, + user_id: int, + auth_node: str, + allow_tag: int = 0, + deny_tag: int = 0, + *, + auth_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.user_id = user_id self.auth_node = auth_node self.allow_tag = allow_tag @@ -265,7 +491,7 @@ class AuthGroup(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('auth_group_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}groups.id'), nullable=False) + group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_groups.id', ondelete='CASCADE'), nullable=False) auth_node = Column(String(128), nullable=False, index=True, comment='授权节点, 由插件检查') allow_tag = Column(Integer, nullable=False, comment='授权标签') deny_tag = Column(Integer, nullable=False, comment='拒绝标签') @@ -273,9 +499,19 @@ class AuthGroup(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - auth_for_group = relationship('Group', back_populates='group_auth') - - def __init__(self, group_id, auth_node, allow_tag=0, deny_tag=0, auth_info=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + auth_group_back_bot_groups = relationship(BotGroup, back_populates='bot_groups_auth_group', + lazy='joined', innerjoin=True) + + def __init__(self, + group_id: int, + auth_node: str, + allow_tag: int = 0, + deny_tag: int = 0, + *, + auth_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.group_id = group_id self.auth_node = auth_node self.allow_tag = allow_tag @@ -290,6 +526,52 @@ def __repr__(self): f"created_at='{self.created_at}', updated_at='{self.updated_at}')>" +# 群组设置信息表, 用于群管/定时任务/欢迎各种需要持久化群设定信息的插件 +class GroupSetting(Base): + __tablename__ = f'{TABLE_PREFIX}groups_settings' + __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} + + id = Column(Integer, Sequence('groups_settings_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_groups.id', ondelete='CASCADE'), nullable=False) + setting_name = Column(String(128), nullable=False, index=True, comment='配置项名称') + main_config = Column(String(128), nullable=False, index=True, comment='主要配置') + secondary_config = Column(String(128), nullable=False, index=True, comment='次要配置, 用于需要多个配置项的情况') + extra_config = Column(String(8192), nullable=False, comment='额外配置, 用于存放无需索引的超长数据') + setting_info = Column(String(128), nullable=True, comment='配置信息') + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + + # 设置级联和关系加载 + groups_settings_back_bot_groups = relationship(BotGroup, back_populates='bot_groups_groups_settings', + lazy='joined', innerjoin=True) + + def __init__( + self, + group_id: int, + setting_name: str, + main_config: str, + *, + secondary_config: str = 'None', + extra_config: str = 'None', + setting_info: Optional[str] = 'None', + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): + self.group_id = group_id + self.setting_name = setting_name + self.main_config = main_config + self.secondary_config = secondary_config + self.extra_config = extra_config + self.setting_info = setting_info + self.created_at = created_at + self.updated_at = updated_at + + def __repr__(self): + return f"" + + # 邮箱表 class EmailBox(Base): __tablename__ = f'{TABLE_PREFIX}email_box' @@ -304,11 +586,19 @@ class EmailBox(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - used_box = relationship('GroupEmailBox', back_populates='has_box', - cascade="all, delete", passive_deletes=True) - - def __init__(self, address: str, server_host: str, password: str, - protocol: str = 'imap', port: int = 993, created_at=None, updated_at=None): + # 设置级联和关系加载 + email_box_group_email_box = relationship('GroupEmailBox', back_populates='group_email_box_back_email_box', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + address: str, + server_host: str, + password: str, + *, + protocol: str = 'imap', + port: int = 993, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.address = address self.server_host = server_host self.protocol = protocol @@ -329,17 +619,25 @@ class GroupEmailBox(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('group_email_box_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - email_box_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}email_box.id'), nullable=False) - group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}groups.id'), nullable=False) + email_box_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}email_box.id', ondelete='CASCADE'), nullable=False) + group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_groups.id', ondelete='CASCADE'), nullable=False) box_info = Column(String(64), nullable=True, comment='群邮箱信息,暂空备用') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - box_for_group = relationship('Group', back_populates='group_box') - - has_box = relationship('EmailBox', back_populates='used_box') - - def __init__(self, email_box_id, group_id, box_info=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + group_email_box_back_bot_groups = relationship(BotGroup, back_populates='bot_groups_group_email_box', + lazy='joined', innerjoin=True) + group_email_box_back_email_box = relationship(EmailBox, back_populates='email_box_group_email_box', + lazy='joined', innerjoin=True) + + def __init__(self, + email_box_id: int, + group_id: int, + *, + box_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.email_box_id = email_box_id self.group_id = group_id self.box_info = box_info @@ -368,7 +666,17 @@ class Email(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - def __init__(self, mail_hash, date, header, sender, to, body, html, created_at=None, updated_at=None): + def __init__(self, + mail_hash: str, + date: str, + header: str, + sender: str, + to: str, + *, + body: Optional[str] = 'Null', + html: Optional[str] = 'Null', + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.mail_hash = mail_hash self.date = date self.header = header @@ -394,22 +702,34 @@ class History(Base): # 表结构 id = Column(Integer, Sequence('history_id_seq'), primary_key=True, nullable=False, index=True, unique=True) time = Column(BigInteger, nullable=False, comment='事件发生的时间戳') - self_id = Column(BigInteger, nullable=False, comment='收到事件的机器人QQ号') - post_type = Column(String(64), nullable=False, comment='事件类型') - detail_type = Column(String(64), nullable=False, comment='消息/通知/请求/元事件类型') - sub_type = Column(String(64), nullable=True, comment='子事件类型') - event_id = Column(BigInteger, nullable=True, comment='事件id, 消息事件为message_id') - group_id = Column(BigInteger, nullable=True, comment='群号') - user_id = Column(BigInteger, nullable=True, comment='发送者QQ号') + self_id = Column(BigInteger, nullable=False, index=True, comment='收到事件的机器人QQ号') + post_type = Column(String(64), nullable=False, index=True, comment='事件类型') + detail_type = Column(String(64), nullable=False, index=True, comment='消息/通知/请求/元事件类型') + sub_type = Column(String(64), nullable=False, index=True, comment='子事件类型') + event_id = Column(BigInteger, nullable=False, index=True, comment='事件id, 消息事件为message_id') + group_id = Column(BigInteger, nullable=False, index=True, comment='群号') + user_id = Column(BigInteger, nullable=False, index=True, comment='发送者QQ号') user_name = Column(String(64), nullable=True, comment='发送者名称') raw_data = Column(String(4096), nullable=True, comment='原始事件内容') msg_data = Column(String(4096), nullable=True, comment='经处理的事件内容') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - def __init__(self, time, self_id, post_type, detail_type, sub_type=None, event_id=None, - group_id=None, user_id=None, user_name=None, raw_data=None, msg_data=None, - created_at=None, updated_at=None): + def __init__(self, + time: int, + self_id: int, + post_type: str, + detail_type: str, + sub_type: str = 'Undefined', + event_id: int = 0, + group_id: int = -1, + user_id: int = -1, + *, + user_name: Optional[str] = None, + raw_data: Optional[str] = None, + msg_data: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.time = time self.self_id = self_id self.post_type = post_type @@ -434,22 +754,38 @@ def __repr__(self): # 订阅表 class Subscription(Base): + """sub_type 订阅类型: + 0-暂留 + 1-B站直播间 + 2-B站动态 + 8-Pixivsion特辑 + 9-Pixiv画师 + """ __tablename__ = f'{TABLE_PREFIX}subscription' __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('subscription_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - # 订阅类型, 0暂留, 1直播间, 2动态, 8Pixivsion - sub_type = Column(Integer, nullable=False, comment='订阅类型,0暂留,1直播间,2动态') + sub_type = Column(Integer, nullable=False, comment='订阅类型') sub_id = Column(Integer, nullable=False, index=True, comment='订阅id,直播为直播间房间号,动态为用户uid') up_name = Column(String(64), nullable=False, comment='up名称') live_info = Column(String(64), nullable=True, comment='相关信息,暂空备用') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - be_sub = relationship('GroupSub', back_populates='sub_by', cascade="all, delete", passive_deletes=True) - be_sub_users = relationship('UserSub', back_populates='sub_by_users', cascade="all, delete", passive_deletes=True) - - def __init__(self, sub_type, sub_id, up_name, live_info=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + subscription_groups_subs = relationship('GroupSub', back_populates='groups_subs_back_subscription', + cascade='all, delete-orphan', passive_deletes=True) + subscription_users_subs = relationship('UserSub', back_populates='users_subs_back_subscription', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + sub_type: int, + sub_id: int, + up_name: str, + *, + live_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.sub_type = sub_type self.sub_id = sub_id self.up_name = up_name @@ -468,16 +804,25 @@ class GroupSub(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('groups_subs_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - sub_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}subscription.id'), nullable=False) - group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}groups.id'), nullable=False) + sub_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}subscription.id', ondelete='CASCADE'), nullable=False) + group_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_groups.id', ondelete='CASCADE'), nullable=False) group_sub_info = Column(String(64), nullable=True, comment='群订阅信息,暂空备用') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - groups_sub = relationship('Group', back_populates='sub_what') - sub_by = relationship('Subscription', back_populates='be_sub') - - def __init__(self, sub_id, group_id, group_sub_info=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + groups_subs_back_bot_groups = relationship(BotGroup, back_populates='bot_groups_groups_subs', + lazy='joined', innerjoin=True) + groups_subs_back_subscription = relationship(Subscription, back_populates='subscription_groups_subs', + lazy='joined', innerjoin=True) + + def __init__(self, + sub_id: int, + group_id: int, + *, + group_sub_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.sub_id = sub_id self.group_id = group_id self.group_sub_info = group_sub_info @@ -496,16 +841,25 @@ class UserSub(Base): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} id = Column(Integer, Sequence('users_subs_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - sub_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}subscription.id'), nullable=False) - user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id'), nullable=False) + sub_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}subscription.id', ondelete='CASCADE'), nullable=False) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}bot_friends.id', ondelete='CASCADE'), nullable=False) user_sub_info = Column(String(64), nullable=True, comment='用户订阅信息,暂空备用') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - users_sub = relationship('User', back_populates='users_sub_what') - sub_by_users = relationship('Subscription', back_populates='be_sub_users') - - def __init__(self, sub_id, user_id, user_sub_info=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + users_subs_back_bot_friends = relationship(Friends, back_populates='bot_friends_users_subs', + lazy='joined', innerjoin=True) + users_subs_back_subscription = relationship(Subscription, back_populates='subscription_users_subs', + lazy='joined', innerjoin=True) + + def __init__(self, + sub_id: int, + user_id: int, + *, + user_sub_info: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.sub_id = sub_id self.user_id = user_id self.user_sub_info = user_sub_info @@ -532,7 +886,14 @@ class Bilidynamic(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - def __init__(self, uid, dynamic_id, dynamic_type, content, created_at=None, updated_at=None): + def __init__(self, + uid: int, + dynamic_id: int, + dynamic_type: int, + content: str, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.uid = uid self.dynamic_id = dynamic_id self.dynamic_type = dynamic_type @@ -547,21 +908,29 @@ def __repr__(self): # 假期表 -class Vocation(Base): - __tablename__ = f'{TABLE_PREFIX}vocations' +class Vacation(Base): + __tablename__ = f'{TABLE_PREFIX}vacations' __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} - id = Column(Integer, Sequence('vocations_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id'), nullable=False) + id = Column(Integer, Sequence('vacations_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + user_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}users.id', ondelete='CASCADE'), nullable=False) status = Column(Integer, nullable=False, comment='请假状态 0-空闲 1-请假 2-工作中') stop_at = Column(DateTime, nullable=True, comment='假期结束时间') reason = Column(String(64), nullable=True, comment='请假理由') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - vocation_for_user = relationship('User', back_populates='vocation') - - def __init__(self, user_id, status, stop_at=None, reason=None, created_at=None, updated_at=None): + # 设置级联和关系加载 + vacations_back_users = relationship(User, back_populates='users_vacations', lazy='joined', innerjoin=True) + + def __init__(self, + user_id: int, + status: int, + *, + stop_at: Optional[datetime] = None, + reason: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.user_id = user_id self.status = status self.stop_at = stop_at @@ -570,33 +939,10 @@ def __init__(self, user_id, status, stop_at=None, reason=None, created_at=None, self.updated_at = updated_at def __repr__(self): - return f"" -# Pixiv tag表 -class PixivTag(Base): - __tablename__ = f'{TABLE_PREFIX}pixiv_tag' - __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} - - id = Column(Integer, Sequence('pixiv_tag_id_seq'), primary_key=True, nullable=False, index=True, unique=True) - tagname = Column(String(128), nullable=False, index=True, unique=True, comment='tag名称') - created_at = Column(DateTime, nullable=True) - updated_at = Column(DateTime, nullable=True) - - has_illusts = relationship('PixivT2I', back_populates='tag_has_illusts', - cascade="all, delete", passive_deletes=True) - - def __init__(self, tagname, created_at=None, updated_at=None): - self.tagname = tagname - self.created_at = created_at - self.updated_at = updated_at - - def __repr__(self): - return f"" - - # Pixiv作品表 class Pixiv(Base): __tablename__ = f'{TABLE_PREFIX}pixiv_illusts' @@ -608,21 +954,38 @@ class Pixiv(Base): uid = Column(Integer, nullable=False, index=True, comment='uid') title = Column(String(128), nullable=False, index=True, comment='title') uname = Column(String(128), nullable=False, index=True, comment='author') - nsfw_tag = Column(Integer, nullable=False, comment='nsfw标签, 0=safe, 1=setu. 2=r18') + nsfw_tag = Column(Integer, nullable=False, index=True, comment='nsfw标签, 0=safe, 1=setu. 2=r18') + width = Column(Integer, nullable=False, comment='原始图片宽度') + height = Column(Integer, nullable=False, comment='原始图片高度') tags = Column(String(1024), nullable=False, comment='tags') url = Column(String(1024), nullable=False, comment='url') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - has_tags = relationship('PixivT2I', back_populates='illust_tags', - cascade="all, delete", passive_deletes=True) - - def __init__(self, pid, uid, title, uname, nsfw_tag, tags, url, created_at=None, updated_at=None): + # 设置级联和关系加载 + pixiv_illusts_pixiv_pages = relationship('PixivPage', back_populates='pixiv_pages_back_pixiv_illusts', + cascade='all, delete-orphan', passive_deletes=True) + + def __init__(self, + pid: int, + uid: int, + title: str, + uname: str, + nsfw_tag: int, + tags: str, + url: str, + *, + width: int = 0, + height: int = 0, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.pid = pid self.uid = uid self.title = title self.uname = uname self.nsfw_tag = nsfw_tag + self.width = width + self.height = height self.tags = tags self.url = url self.created_at = created_at @@ -630,33 +993,52 @@ def __init__(self, pid, uid, title, uname, nsfw_tag, tags, url, created_at=None, def __repr__(self): return f"" -# Pixiv作品-tag表 -class PixivT2I(Base): - __tablename__ = f'{TABLE_PREFIX}pixiv_tag_to_illusts' +# Pixiv作品图片表 +class PixivPage(Base): + __tablename__ = f'{TABLE_PREFIX}pixiv_pages' __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} - id = Column(Integer, Sequence('pixiv_tag_to_illusts_id_seq'), - primary_key=True, nullable=False, index=True, unique=True) - illust_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}pixiv_illusts.id'), nullable=False) - tag_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}pixiv_tag.id'), nullable=False) + id = Column(Integer, Sequence('pixiv_pages_id_seq'), primary_key=True, nullable=False, index=True, unique=True) + illust_id = Column(Integer, ForeignKey(f'{TABLE_PREFIX}pixiv_illusts.id', ondelete='CASCADE'), nullable=False) + page = Column(Integer, nullable=False, index=True, comment='页码') + original = Column(String(1024), nullable=False, comment='original image url') + regular = Column(String(1024), nullable=False, comment='regular image url') + small = Column(String(1024), nullable=False, comment='small image url') + thumb_mini = Column(String(1024), nullable=False, comment='thumb_mini image url') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - illust_tags = relationship('Pixiv', back_populates='has_tags') - tag_has_illusts = relationship('PixivTag', back_populates='has_illusts') - - def __init__(self, illust_id, tag_id, created_at=None, updated_at=None): + # 设置级联和关系加载 + pixiv_pages_back_pixiv_illusts = relationship(Pixiv, back_populates='pixiv_illusts_pixiv_pages', + lazy='joined', innerjoin=True) + + def __init__(self, + illust_id: int, + page: int, + original: str, + regular: str, + small: str, + thumb_mini: str, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.illust_id = illust_id - self.tag_id = tag_id + self.page = page + self.original = original + self.regular = regular + self.small = small + self.thumb_mini = thumb_mini self.created_at = created_at self.updated_at = updated_at def __repr__(self): - return f"" @@ -672,12 +1054,21 @@ class Pixivision(Base): title = Column(String(256), nullable=False, comment='title') description = Column(String(1024), nullable=False, comment='description') tags = Column(String(1024), nullable=False, comment='tags') - illust_id = Column(String(1024), nullable=False, comment='tags') + illust_id = Column(String(1024), nullable=False, comment='article illust_id') url = Column(String(1024), nullable=False, comment='url') created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - def __init__(self, aid, title, description, tags, illust_id, url, created_at=None, updated_at=None): + def __init__(self, + aid: int, + title: str, + description: str, + tags: str, + illust_id: str, + url: str, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.aid = aid self.title = title self.description = description @@ -693,6 +1084,42 @@ def __repr__(self): f"created_at='{self.created_at}', updated_at='{self.updated_at}')>" +# Pixiv用户作品表, 用于P站订阅插件 +# 因画师作品内容不定,将不与萌图/涩图插件共用pixiv_illust表, 避免混入奇怪的东西 +class PixivUserArtwork(Base): + __tablename__ = f'{TABLE_PREFIX}pixiv_users_artworks' + __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} + + # 表结构 + id = Column(Integer, Sequence('pixiv_users_artworks_id_seq'), + primary_key=True, nullable=False, index=True, unique=True) + pid = Column(Integer, nullable=False, index=True, unique=True, comment='pid') + uid = Column(Integer, nullable=False, index=True, comment='uid') + uname = Column(String(128), nullable=False, index=True, comment='author') + title = Column(String(128), nullable=False, index=True, comment='title') + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + + def __init__(self, + pid: int, + uid: int, + uname: str, + title: str, + *, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): + self.pid = pid + self.uid = uid + self.uname = uname + self.title = title + self.created_at = created_at + self.updated_at = updated_at + + def __repr__(self): + return f"" + + # 冷却事件表 class CoolDownEvent(Base): __tablename__ = f'{TABLE_PREFIX}cool_down_event' @@ -710,8 +1137,16 @@ class CoolDownEvent(Base): created_at = Column(DateTime, nullable=True) updated_at = Column(DateTime, nullable=True) - def __init__(self, event_type, stop_at, plugin=None, group_id=None, user_id=None, description=None, - created_at=None, updated_at=None): + def __init__(self, + event_type: str, + stop_at: datetime, + *, + plugin: Optional[str] = None, + group_id: Optional[int] = None, + user_id: Optional[int] = None, + description: Optional[str] = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None): self.event_type = event_type self.stop_at = stop_at self.plugin = plugin diff --git a/omega_miya/utils/Omega_history/__init__.py b/omega_miya/utils/Omega_history/__init__.py deleted file mode 100644 index 19b8a33d..00000000 --- a/omega_miya/utils/Omega_history/__init__.py +++ /dev/null @@ -1,115 +0,0 @@ -from nonebot import MatcherGroup, on_message, on_request, on_notice, logger -from nonebot.plugin import on -from nonebot.typing import T_State -from nonebot.adapters.cqhttp.bot import Bot -from nonebot.adapters.cqhttp.event import Event -from omega_miya.utils.Omega_Base import DBHistory - - -# 注册事件响应器, 处理MessageEvent -Message_history = MatcherGroup(type='message', priority=101, block=True) - -message_history = Message_history.on_message() - - -@message_history.handle() -async def handle_message(bot: Bot, event: Event, state: T_State): - try: - message_id = event.dict().get('message_id') - user_name = event.dict().get('sender').get('card') - if not user_name: - user_name = event.dict().get('sender').get('nickname') - time = event.dict().get('time') - self_id = event.dict().get('self_id') - post_type = event.get_type() - detail_type = event.dict().get(f'{event.get_type()}_type') - sub_type = event.dict().get('sub_type') - group_id = event.dict().get('group_id') - user_id = event.dict().get('user_id') - raw_data = repr(event) - msg_data = str(event.dict().get('message')) - new_event = DBHistory(time=time, self_id=self_id, post_type=post_type, detail_type=detail_type) - res = await new_event.add(sub_type=sub_type, event_id=message_id, group_id=group_id, user_id=user_id, - user_name=user_name, raw_data=raw_data, msg_data=msg_data) - if res.error: - logger.error(f'Message history recording Failed with database error: {res.info}') - except Exception as e: - logger.error(f'Message history recording Failed, error: {repr(e)}') - - -# 注册事件响应器, 处理message_sent -message_sent_history = on(type='message_sent', priority=101, block=True) - - -@message_sent_history.handle() -async def handle_message_sent_history(bot: Bot, event: Event, state: T_State): - try: - user_name = event.dict().get('sender').get('card') - if not user_name: - user_name = event.dict().get('sender').get('nickname') - time = event.dict().get('time') - self_id = event.dict().get('self_id') - post_type = event.get_type() - detail_type = 'self_sent' - sub_type = 'self' - group_id = event.dict().get('group_id') - user_id = event.dict().get('user_id') - raw_data = repr(event) - msg_data = str(event.dict().get('message')) - new_event = DBHistory(time=time, self_id=self_id, post_type=post_type, detail_type=detail_type) - res = await new_event.add(sub_type=sub_type, group_id=group_id, user_id=user_id, user_name=user_name, - raw_data=raw_data, msg_data=msg_data) - if res.error: - logger.error(f'Self-sent Message history recording Failed with database error: {res.info}') - except Exception as e: - logger.error(f'Self-sent Message history recording Failed, error: {repr(e)}') - - -# 注册事件响应器, 处理NoticeEvent -notice_history = on_notice(priority=101, block=True) - - -@notice_history.handle() -async def handle_notice(bot: Bot, event: Event, state: T_State): - try: - time = event.dict().get('time') - self_id = event.dict().get('self_id') - post_type = event.get_type() - detail_type = event.dict().get(f'{event.get_type()}_type') - sub_type = event.dict().get('sub_type') - group_id = event.dict().get('group_id') - user_id = event.dict().get('user_id') - raw_data = repr(event) - msg_data = str(event.dict().get('message')) - new_event = DBHistory(time=time, self_id=self_id, post_type=post_type, detail_type=detail_type) - res = await new_event.add(sub_type=sub_type, group_id=group_id, user_id=user_id, user_name=None, - raw_data=raw_data, msg_data=msg_data) - if res.error: - logger.error(f'Notice history recording Failed with database error: {res.info}') - except Exception as e: - logger.error(f'Notice history recording Failed, error: {repr(e)}') - - -# 注册事件响应器, 处理RequestEvent -request_history = on_request(priority=101, block=True) - - -@request_history.handle() -async def handle_request(bot: Bot, event: Event, state: T_State): - try: - time = event.dict().get('time') - self_id = event.dict().get('self_id') - post_type = event.get_type() - detail_type = event.dict().get(f'{event.get_type()}_type') - sub_type = event.dict().get('sub_type') - group_id = event.dict().get('group_id') - user_id = event.dict().get('user_id') - raw_data = repr(event) - msg_data = str(event.dict().get('message')) - new_event = DBHistory(time=time, self_id=self_id, post_type=post_type, detail_type=detail_type) - res = await new_event.add(sub_type=sub_type, group_id=group_id, user_id=user_id, user_name=None, - raw_data=raw_data, msg_data=msg_data) - if res.error: - logger.error(f'Request history recording Failed with database error: {res.info}') - except Exception as e: - logger.error(f'Request history recording Failed, error: {repr(e)}') diff --git a/omega_miya/utils/Omega_multibot_support/__init__.py b/omega_miya/utils/Omega_multibot_support/__init__.py new file mode 100644 index 00000000..334500ef --- /dev/null +++ b/omega_miya/utils/Omega_multibot_support/__init__.py @@ -0,0 +1,65 @@ +""" +@Author : Ailitonia +@Date : 2021/05/23 19:40 +@FileName : __init__.py +@Project : nonebot2_miya +@Description : Multi-Bot 多协议端接入支持 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from typing import Dict +from nonebot import get_driver, logger +from nonebot.typing import T_State +from nonebot.matcher import Matcher +from nonebot.adapters.cqhttp.event import Event +from nonebot.message import run_preprocessor +from nonebot.exception import IgnoredException +from nonebot.adapters.cqhttp.bot import Bot +from omega_miya.utils.Omega_Base import DBBot + +driver = get_driver() +ONLINE_BOTS: Dict[str, Bot] = {} + + +@driver.on_bot_connect +async def upgrade_connected_bot(bot: Bot): + global ONLINE_BOTS + ONLINE_BOTS.update({bot.self_id: bot}) + # bot_info = await bot.get_login_info() + bot_info = await bot.get_version_info() + info = '||'.join([f'{k}:{v}' for (k, v) in bot_info.items()]) + bot_result = await DBBot(self_qq=int(bot.self_id)).upgrade(status=1, info=info) + if bot_result.success(): + logger.opt(colors=True).info(f'Bot: {bot.self_id} 已连接, ' + f'Database upgrade Success: {bot_result.info}') + else: + logger.opt(colors=True).error(f'Bot: {bot.self_id} 已连接, ' + f'Database upgrade Failed: {bot_result.info}') + + +@driver.on_bot_disconnect +async def upgrade_disconnected_bot(bot: Bot): + global ONLINE_BOTS + ONLINE_BOTS.pop(bot.self_id, None) + bot_result = await DBBot(self_qq=int(bot.self_id)).upgrade(status=0) + if bot_result.success(): + logger.opt(colors=True).warning(f'Bot: {bot.self_id} 已离线, ' + f'Database upgrade Success: {bot_result.info}') + else: + logger.opt(colors=True).error(f'Bot: {bot.self_id} 已离线, ' + f'Database upgrade Failed: {bot_result.info}') + + +@run_preprocessor +async def unique_bot_responding_limit(matcher: Matcher, bot: Bot, event: Event, state: T_State): + # 对于多协议端同时接入, 需匹配event.self_id与bot.self_id, 以保证会话不会被跨bot, 跨群, 跨用户触发 + if bot.self_id != str(event.self_id): + logger.debug(f'Bot {bot.self_id} ignored event which not match self_id with {event.self_id}.') + raise IgnoredException(f'Bot {bot.self_id} ignored event which not match self_id with {event.self_id}.') + + # 对于多协议端同时接入, 各个bot之间不能相互响应, 避免形成死循环 + event_user_id = str(event.dict().get('user_id')) + if event_user_id in [x for x in ONLINE_BOTS.keys() if x != bot.self_id]: + logger.debug(f'Bot {bot.self_id} ignored responding self-relation event with Bot {event_user_id}.') + raise IgnoredException(f'Bot {bot.self_id} ignored responding self-relation event with Bot {event_user_id}.') diff --git a/omega_miya/utils/Omega_plugin_utils/__init__.py b/omega_miya/utils/Omega_plugin_utils/__init__.py index 674f88d8..1dcf2dbe 100644 --- a/omega_miya/utils/Omega_plugin_utils/__init__.py +++ b/omega_miya/utils/Omega_plugin_utils/__init__.py @@ -1,12 +1,15 @@ from typing import Optional from nonebot.plugin import Export from nonebot.typing import T_State -from .rules import * +from .rules import OmegaRules from .encrypt import AESEncryptStr -from .cooldown import * -from .permission import * +from .cooldown import PluginCoolDown +from .permission import PermissionChecker from .http_fetcher import HttpFetcher +from .message_sender import MsgSender from .picture_encoder import PicEncoder +from .picture_effector import PicEffector +from .process_utils import ProcessUtils from .zip_utils import create_zip_file, create_7z_file @@ -44,25 +47,15 @@ def init_permission_state( __all__ = [ 'init_export', 'init_permission_state', - 'has_notice_permission', - 'has_command_permission', - 'has_auth_node', - 'has_level_or_node', - 'permission_level', - 'has_friend_private_permission', + 'OmegaRules', 'AESEncryptStr', 'PluginCoolDown', - 'check_and_set_global_cool_down', - 'check_and_set_plugin_cool_down', - 'check_and_set_group_cool_down', - 'check_and_set_user_cool_down', - 'check_notice_permission', - 'check_command_permission', - 'check_permission_level', - 'check_auth_node', - 'check_friend_private_permission', + 'PermissionChecker', 'HttpFetcher', + 'MsgSender', 'PicEncoder', + 'PicEffector', + 'ProcessUtils', 'create_zip_file', 'create_7z_file' ] diff --git a/omega_miya/utils/Omega_plugin_utils/cooldown.py b/omega_miya/utils/Omega_plugin_utils/cooldown.py index aa63c0c9..a756ccea 100644 --- a/omega_miya/utils/Omega_plugin_utils/cooldown.py +++ b/omega_miya/utils/Omega_plugin_utils/cooldown.py @@ -13,67 +13,63 @@ class PluginCoolDown: user_type: str = field(default='user', init=False) skip_auth_node: str = field(default='skip_cd', init=False) - -async def check_and_set_global_cool_down(minutes: int) -> Result.IntResult: - check = await DBCoolDownEvent.check_global_cool_down_event() - if check.result == 1: - return check - elif check.result == 0: - if minutes <= 0: + @classmethod + async def check_and_set_global_cool_down(cls, minutes: int) -> Result.IntResult: + check = await DBCoolDownEvent.check_global_cool_down_event() + if check.result == 1: return check - await DBCoolDownEvent.add_global_cool_down_event( - stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes)) - return check - else: - return check - - -async def check_and_set_plugin_cool_down(minutes: int, plugin: str) -> Result.IntResult: - check = await DBCoolDownEvent.check_plugin_cool_down_event(plugin=plugin) - if check.result == 1: - return check - elif check.result == 0: - if minutes <= 0: + elif check.result in [0, 2]: + if minutes <= 0: + return check + await DBCoolDownEvent.add_global_cool_down_event( + stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes)) + return check + else: return check - await DBCoolDownEvent.add_plugin_cool_down_event( - stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes), plugin=plugin) - return check - else: - return check - -async def check_and_set_group_cool_down(minutes: int, plugin: str, group_id: int) -> Result.IntResult: - check = await DBCoolDownEvent.check_group_cool_down_event(plugin=plugin, group_id=group_id) - if check.result == 1: - return check - elif check.result == 0: - if minutes <= 0: + @classmethod + async def check_and_set_plugin_cool_down(cls, minutes: int, plugin: str) -> Result.IntResult: + check = await DBCoolDownEvent.check_plugin_cool_down_event(plugin=plugin) + if check.result == 1: + return check + elif check.result in [0, 2]: + if minutes <= 0: + return check + await DBCoolDownEvent.add_plugin_cool_down_event( + stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes), plugin=plugin) + return check + else: return check - await DBCoolDownEvent.add_group_cool_down_event( - stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes), plugin=plugin, group_id=group_id) - return check - else: - return check + @classmethod + async def check_and_set_group_cool_down(cls, minutes: int, plugin: str, group_id: int) -> Result.IntResult: + check = await DBCoolDownEvent.check_group_cool_down_event(plugin=plugin, group_id=group_id) + if check.result == 1: + return check + elif check.result in [0, 2]: + if minutes <= 0: + return check + await DBCoolDownEvent.add_group_cool_down_event( + stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes), plugin=plugin, group_id=group_id) + return check + else: + return check -async def check_and_set_user_cool_down(minutes: int, plugin: str, user_id: int) -> Result.IntResult: - check = await DBCoolDownEvent.check_user_cool_down_event(plugin=plugin, user_id=user_id) - if check.result == 1: - return check - elif check.result == 0: - if minutes <= 0: + @classmethod + async def check_and_set_user_cool_down(cls, minutes: int, plugin: str, user_id: int) -> Result.IntResult: + check = await DBCoolDownEvent.check_user_cool_down_event(plugin=plugin, user_id=user_id) + if check.result == 1: + return check + elif check.result in [0, 2]: + if minutes <= 0: + return check + await DBCoolDownEvent.add_user_cool_down_event( + stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes), plugin=plugin, user_id=user_id) + return check + else: return check - await DBCoolDownEvent.add_user_cool_down_event( - stop_at=datetime.datetime.now() + datetime.timedelta(minutes=minutes), plugin=plugin, user_id=user_id) - return check - else: - return check __all__ = [ - 'PluginCoolDown', - 'check_and_set_global_cool_down', - 'check_and_set_plugin_cool_down', - 'check_and_set_group_cool_down', - 'check_and_set_user_cool_down' + 'PluginCoolDown' ] diff --git a/omega_miya/utils/Omega_plugin_utils/http_fetcher.py b/omega_miya/utils/Omega_plugin_utils/http_fetcher.py index cdc33d31..204aac8e 100644 --- a/omega_miya/utils/Omega_plugin_utils/http_fetcher.py +++ b/omega_miya/utils/Omega_plugin_utils/http_fetcher.py @@ -2,9 +2,11 @@ import aiohttp import aiofiles import nonebot +from urllib.parse import urlparse +from http.cookies import SimpleCookie as SimpleCookie_ from asyncio.exceptions import TimeoutError as TimeoutError_ from dataclasses import dataclass -from typing import Dict, Union, Optional, Any +from typing import Dict, List, Union, Iterable, Optional, Any from nonebot import logger from omega_miya.utils.Omega_Base import DBStatus @@ -23,6 +25,7 @@ class __FetcherResult: info: str status: int headers: dict + cookies: Optional[SimpleCookie_] def success(self) -> bool: if not self.error: @@ -54,6 +57,31 @@ def __repr__(self): return f'' + @dataclass + class FormData(aiohttp.FormData): + def __init__( + self, + fields: Iterable[Any] = (), + *, + is_multipart: bool = False, + is_processed: bool = False, + quote_fields: bool = True, + charset: Optional[str] = None, + boundary: Optional[str] = None + ) -> None: + self._writer = aiohttp.multipart.MultipartWriter("form-data", boundary=boundary) + self._fields: List[Any] = [] + self._is_multipart = is_multipart + self._is_processed = is_processed + self._quote_fields = quote_fields + self._charset = charset + + if isinstance(fields, dict): + fields = list(fields.items()) + elif not isinstance(fields, (list, tuple)): + fields = (fields,) + self.add_fields(*fields) + @classmethod async def __get_proxy(cls, always_return_proxy: bool = False) -> Optional[str]: if always_return_proxy: @@ -93,8 +121,9 @@ async def download_file( self, url: str, path: str, - file_name: str, - params: Dict[str, str] = None, + *, + file_name: Optional[str] = None, + params: Optional[Dict[str, str]] = None, force_proxy: bool = False, **kwargs: Any) -> FetcherTextResult: """ @@ -111,7 +140,12 @@ async def download_file( folder_path = os.path.abspath(path) if not os.path.exists(folder_path): os.makedirs(folder_path) - file_path = os.path.abspath(os.path.join(folder_path, file_name)) + + if file_name: + file_path = os.path.abspath(os.path.join(folder_path, file_name)) + else: + file_name = os.path.basename(urlparse(url).path) if os.path.basename(urlparse(url).path) else str(hash(url)) + file_path = os.path.abspath(os.path.join(folder_path, file_name)) proxy = await self.__get_proxy(always_return_proxy=force_proxy) num_of_attempts = 0 @@ -126,10 +160,12 @@ async def download_file( file_bytes = await rp.read() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies async with aiofiles.open(file_path, 'wb') as f: await f.write(file_bytes) result = self.FetcherTextResult( - error=False, info='Success', status=status, headers=headers, result=file_path) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=file_path) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -147,7 +183,8 @@ async def download_file( f'Failed too many times in download_file.\n' f'url: {url}\nparams: {params}') return self.FetcherTextResult( - error=True, info='Failed too many times in download_file', status=-1, headers={}, result='') + error=True, info='Failed too many times in download_file', + status=-1, headers={}, cookies=None, result='') async def get_json( self, @@ -168,8 +205,10 @@ async def get_json( result_json = await rp.json() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies result = self.FetcherJsonResult( - error=False, info='Success', status=status, headers=headers, result=result_json) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=result_json) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -187,7 +226,8 @@ async def get_json( f'Failed too many times in get_json.\n' f'url: {url}\nparams: {params}') return self.FetcherJsonResult( - error=True, info='Failed too many times in get_json', status=-1, headers={}, result={}) + error=True, info='Failed too many times in get_json', + status=-1, headers={}, cookies=None, result={}) async def get_text( self, @@ -208,8 +248,10 @@ async def get_text( result_text = await rp.text() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies result = self.FetcherTextResult( - error=False, info='Success', status=status, headers=headers, result=result_text) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=result_text) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -227,7 +269,8 @@ async def get_text( f'Failed too many times in get_text.\n' f'url: {url}\nparams: {params}') return self.FetcherTextResult( - error=True, info='Failed too many times in get_text', status=-1, headers={}, result='') + error=True, info='Failed too many times in get_text', + status=-1, headers={}, cookies=None, result='') async def get_bytes( self, @@ -248,8 +291,10 @@ async def get_bytes( result_bytes = await rp.read() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies result = self.FetcherBytesResult( - error=False, info='Success', status=status, headers=headers, result=result_bytes) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=result_bytes) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -267,14 +312,15 @@ async def get_bytes( f'Failed too many times in get_bytes.\n' f'url: {url}\nparams: {params}') return self.FetcherBytesResult( - error=True, info='Failed too many times in get_bytes', status=-1, headers={}, result=b'') + error=True, info='Failed too many times in get_bytes', + status=-1, headers={}, cookies=None, result=b'') async def post_json( self, url: str, params: Dict[str, str] = None, json: Dict[str, Any] = None, - data: Dict[str, Any] = None, + data: Union[FormData, Dict[str, Any]] = None, force_proxy: bool = False, **kwargs: Any) -> FetcherJsonResult: proxy = await self.__get_proxy(always_return_proxy=force_proxy) @@ -290,8 +336,10 @@ async def post_json( result_json = await rp.json() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies result = self.FetcherJsonResult( - error=False, info='Success', status=status, headers=headers, result=result_json) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=result_json) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -309,14 +357,15 @@ async def post_json( f'Failed too many times in post_json.\n' f'url: {url}\nparams: {params}\njson: {json}\ndata: {data}') return self.FetcherJsonResult( - error=True, info='Failed too many times in post_json', status=-1, headers={}, result={}) + error=True, info='Failed too many times in post_json', + status=-1, headers={}, cookies=None, result={}) async def post_text( self, url: str, params: Dict[str, str] = None, json: Dict[str, Any] = None, - data: Dict[str, Any] = None, + data: Union[FormData, Dict[str, Any]] = None, force_proxy: bool = False, **kwargs: Any) -> FetcherTextResult: proxy = await self.__get_proxy(always_return_proxy=force_proxy) @@ -332,8 +381,10 @@ async def post_text( result_text = await rp.text() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies result = self.FetcherTextResult( - error=False, info='Success', status=status, headers=headers, result=result_text) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=result_text) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -351,14 +402,15 @@ async def post_text( f'Failed too many times in post_text.\n' f'url: {url}\nparams: {params}\njson: {json}\ndata: {data}') return self.FetcherTextResult( - error=True, info='Failed too many times in post_text', status=-1, headers={}, result='') + error=True, info='Failed too many times in post_text', + status=-1, headers={}, cookies=None, result='') async def post_bytes( self, url: str, params: Dict[str, str] = None, json: Dict[str, Any] = None, - data: Dict[str, Any] = None, + data: Union[FormData, Dict[str, Any]] = None, force_proxy: bool = False, **kwargs: Any) -> FetcherBytesResult: proxy = await self.__get_proxy(always_return_proxy=force_proxy) @@ -374,8 +426,10 @@ async def post_bytes( result_bytes = await rp.read() status = rp.status headers = dict(rp.headers) + cookies = rp.cookies result = self.FetcherBytesResult( - error=False, info='Success', status=status, headers=headers, result=result_bytes) + error=False, info='Success', + status=status, headers=headers, cookies=cookies, result=result_bytes) return result except TimeoutError_: logger.opt(colors=True).warning( @@ -393,4 +447,5 @@ async def post_bytes( f'Failed too many times in post_bytes.\n' f'url: {url}\nparams: {params}\njson: {json}\ndata: {data}') return self.FetcherBytesResult( - error=True, info='Failed too many times in post_bytes', status=-1, headers={}, result=b'') + error=True, info='Failed too many times in post_bytes', + status=-1, headers={}, cookies=None, result=b'') diff --git a/omega_miya/utils/Omega_plugin_utils/message_sender.py b/omega_miya/utils/Omega_plugin_utils/message_sender.py new file mode 100644 index 00000000..679eeebf --- /dev/null +++ b/omega_miya/utils/Omega_plugin_utils/message_sender.py @@ -0,0 +1,282 @@ +""" +@Author : Ailitonia +@Date : 2021/05/27 22:04 +@FileName : message_sender.py +@Project : nonebot2_miya +@Description : Bot Message Sender +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from nonebot import logger +from typing import Optional, List, Union +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.message import Message, MessageSegment +from omega_miya.utils.Omega_Base import DBBot, DBBotGroup, DBFriend, DBSubscription + + +class MsgSender(object): + def __init__(self, bot: Bot, log_flag: Optional[str] = 'DefaultSender'): + self.bot = bot + self.self_bot = DBBot(self_qq=int(bot.self_id)) + self.log_flag = f'MsgSender/{log_flag}/Bot[{bot.self_id}]' + + async def safe_broadcast_groups_subscription( + self, subscription: DBSubscription, message: Union[str, Message, MessageSegment]): + """ + 向所有具有某个订阅且启用了通知权限 notice permission 的群组发送消息 + """ + # 获取所有需要通知的群组 + notice_group_res = await subscription.sub_group_list_by_notice_permission(self_bot=self.self_bot, + notice_permission=1) + if notice_group_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription ' + f'{subscription.sub_type}/{subscription.sub_id} broadcast message, ' + f'getting sub group list with notice permission failed, error: {notice_group_res.info}') + return + + for group_id in notice_group_res.result: + try: + await self.bot.send_group_msg(group_id=group_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending subscription ' + f'{subscription.sub_type}/{subscription.sub_id} broadcast message ' + f'to group: {group_id} failed, error: {repr(e)}') + continue + + async def safe_broadcast_groups_subscription_node_custom( + self, subscription: DBSubscription, message_list: List[Union[str, Message, MessageSegment]], + *, + custom_nickname: str = 'Ωμεγα' + ): + """ + 向所有具有某个订阅且启用了通知权限 notice permission 的群组发送自定义转发消息节点 + 仅支持 cq-http + """ + # 获取所有需要通知的群组 + notice_group_res = await subscription.sub_group_list_by_notice_permission(self_bot=self.self_bot, + notice_permission=1) + if notice_group_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription ' + f'{subscription.sub_type}/{subscription.sub_id} broadcast node_custom message, ' + f'getting sub group list with notice permission failed, error: {notice_group_res.info}') + return + + # 构造自定义消息节点 + custom_user_id = self.bot.self_id + node_message = [] + for msg in message_list: + if not msg: + logger.opt(colors=True).warning( + f'{self.log_flag} | A None-type message in message_list.') + continue + node_message.append({ + "type": "node", + "data": { + "name": custom_nickname, + "uin": custom_user_id, + "content": msg + } + }) + + for group_id in notice_group_res.result: + try: + await self.bot.send_group_forward_msg(group_id=group_id, messages=node_message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending subscription ' + f'{subscription.sub_type}/{subscription.sub_id} broadcast node_custom message ' + f'to group: {group_id} failed, error: {repr(e)}') + continue + + async def safe_send_group_node_custom( + self, group_id: int, message_list: List[Union[str, Message, MessageSegment]], + *, + custom_nickname: str = 'Ωμεγα' + ): + """ + 向某个群组发送自定义转发消息节点 + 仅支持 cq-http + """ + # 构造自定义消息节点 + custom_user_id = self.bot.self_id + node_message = [] + for msg in message_list: + if not msg: + logger.opt(colors=True).warning( + f'{self.log_flag} | A None-type message in message_list.') + continue + node_message.append({ + "type": "node", + "data": { + "name": custom_nickname, + "user_id": custom_user_id, + "uin": custom_user_id, + "content": msg + } + }) + + try: + await self.bot.send_group_forward_msg(group_id=group_id, messages=node_message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending node_custom message ' + f'to group: {group_id} failed, error: {repr(e)}') + + async def safe_broadcast_friends_subscription( + self, subscription: DBSubscription, message: Union[str, Message, MessageSegment]): + """ + 向所有具有某个订阅且启用了通知权限 notice permission 的好友发送消息 + """ + # 获取所有需要通知的好友 + notice_friends_res = await subscription.sub_user_list_by_private_permission(self_bot=self.self_bot, + private_permission=1) + if notice_friends_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription ' + f'{subscription.sub_type}/{subscription.sub_id} broadcast message, ' + f'getting sub friends list with private permission failed, error: {notice_friends_res.info}') + return + + for user_id in notice_friends_res.result: + try: + await self.bot.send_private_msg(user_id=user_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending subscription ' + f'{subscription.sub_type}/{subscription.sub_id} broadcast message ' + f'to user: {user_id} failed, error: {repr(e)}') + continue + + async def safe_send_msg_enabled_friends(self, message: Union[str, Message, MessageSegment]): + """ + 向所有具有好友权限 private permission (已启用bot命令) 的好友发送消息 + """ + # 获取所有启用 private permission 好友 + enabled_friends_res = await DBFriend.list_exist_friends_by_private_permission(self_bot=self.self_bot, + private_permission=1) + if enabled_friends_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send message to friends, ' + f'getting enabled friends list with private permission failed, error: {enabled_friends_res.info}') + return + + for user_id in enabled_friends_res.result: + try: + await self.bot.send_private_msg(user_id=user_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending message to friend: {user_id} failed, error: {repr(e)}') + continue + + async def safe_send_msg_all_friends(self, message: Union[str, Message, MessageSegment]): + """ + 向所有好友发送消息 + """ + # 获取数据库中所有好友 + all_friends_res = await DBFriend.list_exist_friends(self_bot=self.self_bot) + if all_friends_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send message to friends, ' + f'getting all friends list with private permission failed, error: {all_friends_res.info}') + return + + for user_id in all_friends_res.result: + try: + await self.bot.send_private_msg(user_id=user_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending message to friend: {user_id} failed, error: {repr(e)}') + continue + + async def safe_send_msg_enabled_command_groups(self, message: Union[str, Message, MessageSegment]): + """ + 向所有具有命令权限 command permission 的群组发送消息 + """ + # 获取所有需要通知的群组 + command_group_res = await DBBotGroup.list_exist_bot_groups_by_command_permissions(self_bot=self.self_bot, + command_permissions=1) + if command_group_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription message to command groups, ' + f'getting command group list failed, error: {command_group_res.info}') + return + + for group_id in command_group_res.result: + try: + await self.bot.send_group_msg(group_id=group_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending message to group: {group_id} failed, error: {repr(e)}') + continue + + async def safe_send_msg_enabled_notice_groups(self, message: Union[str, Message, MessageSegment]): + """ + 向所有具有通知权限 notice permission 的群组发送消息 + """ + # 获取所有需要通知的群组 + notice_group_res = await DBBotGroup.list_exist_bot_groups_by_notice_permissions(self_bot=self.self_bot, + notice_permissions=1) + if notice_group_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription message to notice groups, ' + f'getting notice group list failed, error: {notice_group_res.info}') + return + + for group_id in notice_group_res.result: + try: + await self.bot.send_group_msg(group_id=group_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending message to group: {group_id} failed, error: {repr(e)}') + continue + + async def safe_send_msg_permission_level_groups( + self, permission_level: int, message: Union[str, Message, MessageSegment]): + """ + 向所有大于等于指定权限等级 permission level 的群组发送消息 + """ + # 获取所有需要通知的群组 + level_group_res = await DBBotGroup.list_exist_bot_groups_by_permission_level(self_bot=self.self_bot, + permission_level=permission_level) + if level_group_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription message to groups had level, ' + f'getting permission level group list failed, error: {level_group_res.info}') + return + + for group_id in level_group_res.result: + try: + await self.bot.send_group_msg(group_id=group_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending message to group: {group_id} failed, error: {repr(e)}') + continue + + async def safe_send_msg_all_groups(self, message: Union[str, Message, MessageSegment]): + """ + 向所有群组发送消息 + """ + # 获取所有需要通知的群组 + all_group_res = await DBBotGroup.list_exist_bot_groups(self_bot=self.self_bot) + if all_group_res.error: + logger.opt(colors=True).error( + f'{self.log_flag} | Can not send subscription message to all groups, ' + f'getting permission all group list failed, error: {all_group_res.info}') + return + + for group_id in all_group_res.result: + try: + await self.bot.send_group_msg(group_id=group_id, message=message) + except Exception as e: + logger.opt(colors=True).warning( + f'{self.log_flag} | Sending message to group: {group_id} failed, error: {repr(e)}') + continue + + +__all__ = [ + 'MsgSender' +] diff --git a/omega_miya/utils/Omega_plugin_utils/permission.py b/omega_miya/utils/Omega_plugin_utils/permission.py index 2effaab7..b8f0e851 100644 --- a/omega_miya/utils/Omega_plugin_utils/permission.py +++ b/omega_miya/utils/Omega_plugin_utils/permission.py @@ -1,58 +1,54 @@ -from omega_miya.utils.Omega_Base import DBFriend, DBGroup, DBAuth - - -async def check_notice_permission(group_id: int) -> bool: - res = await DBGroup(group_id=group_id).permission_notice() - if res.result == 1: - return True - else: - return False - - -async def check_command_permission(group_id: int) -> bool: - res = await DBGroup(group_id=group_id).permission_command() - if res.result == 1: - return True - else: - return False - - -async def check_permission_level(group_id: int, level: int) -> bool: - res = await DBGroup(group_id=group_id).permission_level() - if res.result >= level: - return True - else: - return False - - -async def check_auth_node(auth_id: int, auth_type: str, auth_node: str) -> int: - auth = DBAuth(auth_id=auth_id, auth_type=auth_type, auth_node=auth_node) - tag_res = await auth.tags_info() - allow_tag = tag_res.result[0] - deny_tag = tag_res.result[1] - - if allow_tag == 1 and deny_tag == 0: - return 1 - elif allow_tag == -2 and deny_tag == -2: - return 0 - else: - return -1 - - -async def check_friend_private_permission(user_id: int) -> bool: - res = await DBFriend(user_id=user_id).get_private_permission() - if res.error: - return False - elif res.result == 1: - return True - else: - return False +from omega_miya.utils.Omega_Base import DBBot, DBFriend, DBBotGroup, DBAuth + + +class PermissionChecker(object): + def __init__(self, self_bot: DBBot): + self.self_bot = self_bot + + async def check_notice_permission(self, group_id: int) -> bool: + res = await DBBotGroup(group_id=group_id, self_bot=self.self_bot).permission_notice() + if res.result == 1: + return True + else: + return False + + async def check_command_permission(self, group_id: int) -> bool: + res = await DBBotGroup(group_id=group_id, self_bot=self.self_bot).permission_command() + if res.result == 1: + return True + else: + return False + + async def check_permission_level(self, group_id: int, level: int) -> bool: + res = await DBBotGroup(group_id=group_id, self_bot=self.self_bot).permission_level() + if res.result >= level: + return True + else: + return False + + async def check_auth_node(self, auth_id: int, auth_type: str, auth_node: str) -> int: + auth = DBAuth(self_bot=self.self_bot, auth_id=auth_id, auth_type=auth_type, auth_node=auth_node) + tag_res = await auth.tags_info() + allow_tag = tag_res.result[0] + deny_tag = tag_res.result[1] + + if allow_tag == 1 and deny_tag == 0: + return 1 + elif allow_tag == -2 and deny_tag == -2: + return 0 + else: + return -1 + + async def check_friend_private_permission(self, user_id: int) -> bool: + res = await DBFriend(user_id=user_id, self_bot=self.self_bot).get_private_permission() + if res.error: + return False + elif res.result == 1: + return True + else: + return False __all__ = [ - 'check_notice_permission', - 'check_command_permission', - 'check_permission_level', - 'check_auth_node', - 'check_friend_private_permission' + 'PermissionChecker' ] diff --git a/omega_miya/utils/Omega_plugin_utils/picture_effector.py b/omega_miya/utils/Omega_plugin_utils/picture_effector.py new file mode 100644 index 00000000..0e6c4fe1 --- /dev/null +++ b/omega_miya/utils/Omega_plugin_utils/picture_effector.py @@ -0,0 +1,94 @@ +""" +@Author : Ailitonia +@Date : 2021/06/02 0:35 +@FileName : picture_effector.py +@Project : nonebot2_miya +@Description : Picture Effector +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import asyncio +import random +from typing import Optional +from io import BytesIO +from PIL import Image, ImageFilter, ImageEnhance +from omega_miya.utils.Omega_Base import Result + + +class PicEffector(object): + def __init__(self, image: bytes): + self.image = image + + def add_blank_bytes(self, bytes_num: int = 16) -> bytes: + return self.image + b' '*bytes_num + + async def gaussian_blur(self, radius: Optional[int] = None) -> Result.BytesResult: + def __handle() -> Result.BytesResult: + with BytesIO() as byte_file: + byte_file.write(self.image) + # 处理图片 + mk_image = Image.open(byte_file) + if radius: + blur_radius = radius + else: + blur_radius = mk_image.width // 16 + blur_image = mk_image.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + with BytesIO() as mk_byte_file: + blur_image.save(mk_byte_file, format='PNG') + img_bytes = mk_byte_file.getvalue() + return Result.BytesResult(error=False, info='Success', result=img_bytes) + + loop = asyncio.get_running_loop() + try: + result = await loop.run_in_executor(None, __handle) + except Exception as e: + result = Result.BytesResult(error=True, info=f'gaussian_blur failed: {repr(e)}', result=b'') + return result + + async def gaussian_noise( + self, + *, + sigma: Optional[float] = 8, + enable_random: bool = True, + mask_factor: Optional[float] = 0.25) -> Result.BytesResult: + """ + 为图片添加肉眼不可见的底噪 + :param sigma: 噪声sigma, 默认值8 + :param enable_random: 为噪声sigma添加随机扰动, 默认值True + :param mask_factor: 噪声蒙版透明度修正, 默认值0.25 + :return: + """ + def __handle() -> Result.BytesResult: + with BytesIO() as byte_file: + byte_file.write(self.image) + # 处理图片 + mk_image: Image.Image = Image.open(byte_file) + width, height = mk_image.width, mk_image.height + # 为sigma添加随机扰动 + if enable_random: + sigma_ = sigma * (1 + 0.1 * random.random()) + else: + sigma_ = sigma + # 生成高斯噪声底图 + noise_image = Image.effect_noise(size=(width, height), sigma=sigma_) + # 生成底噪蒙版 + noise_mask = ImageEnhance.Brightness(noise_image.convert('L')).enhance(factor=mask_factor) + with BytesIO() as mk_byte_file: + # 叠加噪声图层 + mk_image.paste(noise_image, (0, 0), mask=noise_mask) + mk_image.save(mk_byte_file, format='PNG') + img_bytes = mk_byte_file.getvalue() + return Result.BytesResult(error=False, info='Success', result=img_bytes) + + loop = asyncio.get_running_loop() + try: + result = await loop.run_in_executor(None, __handle) + except Exception as e: + result = Result.BytesResult(error=True, info=f'gaussian_noise failed: {repr(e)}', result=b'') + return result + + +__all__ = [ + 'PicEffector' +] diff --git a/omega_miya/utils/Omega_plugin_utils/picture_encoder.py b/omega_miya/utils/Omega_plugin_utils/picture_encoder.py index 61e166a4..9bf30feb 100644 --- a/omega_miya/utils/Omega_plugin_utils/picture_encoder.py +++ b/omega_miya/utils/Omega_plugin_utils/picture_encoder.py @@ -1,47 +1,88 @@ import os import base64 -from nonebot import logger -from dataclasses import dataclass +import aiofiles +import pathlib +from nonebot import get_driver, logger +from omega_miya.utils.Omega_Base import Result +from omega_miya.utils.Omega_plugin_utils import HttpFetcher -class PicEncoder(object): - @dataclass - class __Result: - error: bool - info: str - result: str - - def success(self) -> bool: - if not self.error: - return True - else: - return False +driver = get_driver() +TMP_PATH = driver.config.tmp_path_ + +class PicEncoder(object): @classmethod - def file_to_b64(cls, file_path: str) -> __Result: + async def file_to_b64(cls, file_path: str) -> Result.TextResult: abs_path = os.path.abspath(file_path) if not os.path.exists(abs_path): - return cls.__Result(error=True, info='File not exists', result='') + return Result.TextResult(error=True, info='File not exists', result='') try: - with open(abs_path, 'rb') as f: - b64 = base64.b64encode(f.read()) + async with aiofiles.open(abs_path, 'rb') as af: + b64 = base64.b64encode(await af.read()) b64 = str(b64, encoding='utf-8') b64 = 'base64://' + b64 - return cls.__Result(error=True, info='Success', result=b64) + return Result.TextResult(error=False, info='Success', result=b64) except Exception as e: logger.opt(colors=True).warning(f'PicEncoder file_to_b64 failed, Error: {repr(e)}') - return cls.__Result(error=True, info=repr(e), result='') + return Result.TextResult(error=True, info=repr(e), result='') @classmethod - def bytes_to_b64(cls, image: bytes) -> __Result: + async def bytes_to_file(cls, image: bytes, *, folder_flag: str = 'PicEncoder') -> Result.TextResult: + # 检查保存文件路径 + folder_path = os.path.abspath(os.path.join(TMP_PATH, folder_flag)) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + file_path = os.path.abspath(os.path.join(folder_path, str(hash(image)))) + try: + async with aiofiles.open(file_path, 'wb') as af: + await af.write(image) + file_url = pathlib.Path(file_path).as_uri() + return Result.TextResult(error=False, info='Success', result=file_url) + except Exception as e: + logger.opt(colors=True).warning(f'PicEncoder bytes_to_file failed, Error: {repr(e)}') + return Result.TextResult(error=True, info=repr(e), result='') + + @classmethod + def bytes_to_b64(cls, image: bytes) -> Result.TextResult: try: b64 = str(base64.b64encode(image), encoding='utf-8') b64 = 'base64://' + b64 - return cls.__Result(error=False, info='Success', result=b64) + return Result.TextResult(error=False, info='Success', result=b64) except Exception as e: logger.opt(colors=True).warning(f'PicEncoder bytes_to_b64 failed, Error: {repr(e)}') - return cls.__Result(error=True, info=repr(e), result='') + return Result.TextResult(error=True, info=repr(e), result='') + + def __init__( + self, + pic_url: str, + *, + headers: dict = None, + params: dict = None + ): + self.__pic_url = pic_url + self.__headers = headers + self.__params = params + + async def get_base64(self) -> Result.TextResult: + fetcher = HttpFetcher(timeout=30, attempt_limit=2, flag='PicEncoder_get_base64', headers=self.__headers) + bytes_result = await fetcher.get_bytes(url=self.__pic_url) + if bytes_result.error: + return Result.TextResult(error=True, info='Image download failed', result='') + + encode_result = self.bytes_to_b64(image=bytes_result.result) + return encode_result + + async def get_file(self, *, folder_flag: str = 'PicEncoder') -> Result.TextResult: + fetcher = HttpFetcher(timeout=30, attempt_limit=2, flag='PicEncoder_get_base64', headers=self.__headers) + bytes_result = await fetcher.get_bytes(url=self.__pic_url) + if bytes_result.error: + return Result.TextResult(error=True, info='Image download failed', result='') + + encode_result = await self.bytes_to_file(image=bytes_result.result, folder_flag=folder_flag) + return encode_result __all__ = [ diff --git a/omega_miya/utils/Omega_plugin_utils/process_utils.py b/omega_miya/utils/Omega_plugin_utils/process_utils.py new file mode 100644 index 00000000..4e86e142 --- /dev/null +++ b/omega_miya/utils/Omega_plugin_utils/process_utils.py @@ -0,0 +1,72 @@ +""" +@Author : Ailitonia +@Date : 2021/07/29 19:29 +@FileName : process_utils.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import asyncio +from typing import List, Tuple, Awaitable, Optional, Any +from nonebot import logger + + +class ProcessUtils(object): + @classmethod + async def fragment_process( + cls, + tasks: List[Awaitable[Any]], + fragment_size: Optional[int] = None, + *, + log_flag: str = 'Default') -> Tuple: + """ + 分段运行一批需要并行的异步函数 + :param tasks: 任务序列 + :param fragment_size: 单次并行的数量 + :param log_flag: 日志标记 + """ + all_count = len(tasks) + if all_count <= 0: + raise ValueError('Param "tasks" must not be null') + elif not isinstance(fragment_size, int): + raise ValueError('Param "fragment_size" must be int') + elif not fragment_size: + fragment_size = all_count + elif fragment_size > all_count: + fragment_size = all_count + elif fragment_size <= 0: + raise ValueError('Param "fragment_size" must be int') + + # 切分切片列表 + fragment_list = [] + for i in range(0, all_count, fragment_size): + fragment_list.append(tasks[i:i + fragment_size]) + fragment_count = len(fragment_list) + + # 执行进度及统计计数 + process_rate_count = 0 + # 最终返回的结果 + result = [] + # 每个切片打包一个任务 + for fragment in fragment_list: + # 进行异步处理 + try: + _result = await asyncio.gather(*fragment) + result.extend(_result) + except Exception as e: + logger.error(f'Fragment process | {log_flag} processing error: {repr(e)}') + continue + + # 显示进度 + process_rate_count += 1 + logger.info(f'Fragment process | {log_flag} processing: {process_rate_count}/{fragment_count}') + + logger.info(f'Fragment process | {log_flag} process complete, total tasks: {all_count}') + return tuple(result) + + +__all__ = [ + 'ProcessUtils' +] diff --git a/omega_miya/utils/Omega_plugin_utils/rules.py b/omega_miya/utils/Omega_plugin_utils/rules.py index a762c885..58698c19 100644 --- a/omega_miya/utils/Omega_plugin_utils/rules.py +++ b/omega_miya/utils/Omega_plugin_utils/rules.py @@ -2,158 +2,158 @@ from nonebot.typing import T_State from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import Event -from omega_miya.utils.Omega_Base import DBFriend, DBGroup, DBAuth - - -# Plugin permission rule -# Only using for group -def has_notice_permission() -> Rule: - async def _has_notice_permission(bot: Bot, event: Event, state: T_State) -> bool: - detail_type = event.dict().get(f'{event.get_type()}_type') - group_id = event.dict().get('group_id') - # 检查当前消息类型 - if detail_type != 'group': - return False - else: - res = await DBGroup(group_id=group_id).permission_notice() - if res.result == 1: - return True +from omega_miya.utils.Omega_Base import DBBot, DBFriend, DBBotGroup, DBAuth + + +class OmegaRules(object): + # Plugin permission rule + # Only using for group + @classmethod + def has_group_notice_permission(cls) -> Rule: + async def _has_group_notice_permission(bot: Bot, event: Event, state: T_State) -> bool: + detail_type = event.dict().get(f'{event.get_type()}_type') + group_id = event.dict().get('group_id') + self_bot = DBBot(self_qq=int(bot.self_id)) + # 检查当前消息类型 + if not str(detail_type).startswith('group'): + return False else: + res = await DBBotGroup(group_id=group_id, self_bot=self_bot).permission_notice() + if res.result == 1: + return True + else: + return False + return Rule(_has_group_notice_permission) + + @classmethod + def has_group_command_permission(cls) -> Rule: + async def _has_group_command_permission(bot: Bot, event: Event, state: T_State) -> bool: + detail_type = event.dict().get(f'{event.get_type()}_type') + group_id = event.dict().get('group_id') + self_bot = DBBot(self_qq=int(bot.self_id)) + # 检查当前消息类型 + if not str(detail_type).startswith('group'): return False - return Rule(_has_notice_permission) - - -def has_command_permission() -> Rule: - async def _has_command_permission(bot: Bot, event: Event, state: T_State) -> bool: - detail_type = event.dict().get(f'{event.get_type()}_type') - group_id = event.dict().get('group_id') - # 检查当前消息类型 - if detail_type != 'group': - return False - else: - res = await DBGroup(group_id=group_id).permission_command() - if res.result == 1: - return True else: + res = await DBBotGroup(group_id=group_id, self_bot=self_bot).permission_command() + if res.result == 1: + return True + else: + return False + return Rule(_has_group_command_permission) + + @classmethod + def has_group_permission_level(cls, level: int) -> Rule: + async def _has_group_permission_level(bot: Bot, event: Event, state: T_State) -> bool: + detail_type = event.dict().get(f'{event.get_type()}_type') + group_id = event.dict().get('group_id') + self_bot = DBBot(self_qq=int(bot.self_id)) + # 检查当前消息类型 + if not str(detail_type).startswith('group'): return False - return Rule(_has_command_permission) - + else: + res = await DBBotGroup(group_id=group_id, self_bot=self_bot).permission_level() + if res.result >= level: + return True + else: + return False + return Rule(_has_group_permission_level) + + # 权限节点检查 + @classmethod + def has_auth_node(cls, *auth_nodes: str) -> Rule: + async def _has_auth_node(bot: Bot, event: Event, state: T_State) -> bool: + auth_node = '.'.join(auth_nodes) + detail_type = event.dict().get(f'{event.get_type()}_type') + group_id = event.dict().get('group_id') + user_id = event.dict().get('user_id') + self_bot = DBBot(self_qq=int(bot.self_id)) + # 检查当前消息类型 + if detail_type == 'private': + user_auth = DBAuth(self_bot=self_bot, auth_id=user_id, auth_type='user', auth_node=auth_node) + user_tag_res = await user_auth.tags_info() + allow_tag = user_tag_res.result[0] + deny_tag = user_tag_res.result[1] + elif str(detail_type).startswith('group'): + group_auth = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=auth_node) + group_tag_res = await group_auth.tags_info() + allow_tag = group_tag_res.result[0] + deny_tag = group_tag_res.result[1] + else: + allow_tag = 0 + deny_tag = 0 -# 规划权限等级(暂定): 10+一般插件, 20各类订阅插件, 30+限制插件(涉及调用api), 50+涩图插件 -def permission_level(level: int) -> Rule: - async def _has_permission_level(bot: Bot, event: Event, state: T_State) -> bool: - detail_type = event.dict().get(f'{event.get_type()}_type') - group_id = event.dict().get('group_id') - # 检查当前消息类型 - if detail_type != 'group': - return False - else: - res = await DBGroup(group_id=group_id).permission_level() - if res.result >= level: + if allow_tag == 1 and deny_tag == 0: return True else: return False - return Rule(_has_permission_level) - - -# 权限节点检查 -def has_auth_node(*auth_nodes: str) -> Rule: - async def _has_auth_node(bot: Bot, event: Event, state: T_State) -> bool: - auth_node = '.'.join(auth_nodes) - detail_type = event.dict().get(f'{event.get_type()}_type') - group_id = event.dict().get('group_id') - user_id = event.dict().get('user_id') - # 检查当前消息类型 - if detail_type == 'private': - user_auth = DBAuth(auth_id=user_id, auth_type='user', auth_node=auth_node) - user_tag_res = await user_auth.tags_info() - allow_tag = user_tag_res.result[0] - deny_tag = user_tag_res.result[1] - elif detail_type == 'group' or detail_type == 'group_upload': - group_auth = DBAuth(auth_id=group_id, auth_type='group', auth_node=auth_node) - group_tag_res = await group_auth.tags_info() - allow_tag = group_tag_res.result[0] - deny_tag = group_tag_res.result[1] - else: - allow_tag = 0 - deny_tag = 0 - - if allow_tag == 1 and deny_tag == 0: - return True - else: - return False - return Rule(_has_auth_node) - - -# 由于目前nb2暂不支持or连接rule, 因此将or逻辑放在rule内处理 -def has_level_or_node(level: int, *auth_nodes: str) -> Rule: - """ - :param level: 需要群组权限等级 - :param auth_nodes: 需要的权限节点 - :return: 群组权限等级大于要求等级或者具备权限节点, 权限节点为deny则拒绝 - """ - async def _has_level_or_node(bot: Bot, event: Event, state: T_State) -> bool: - auth_node = '.'.join(auth_nodes) - detail_type = event.dict().get(f'{event.get_type()}_type') - group_id = event.dict().get('group_id') - user_id = event.dict().get('user_id') - - # level检查部分 - if detail_type != 'group': - level_checker = False - else: - level_res = await DBGroup(group_id=group_id).permission_level() - if level_res.result >= level: - level_checker = True - else: + return Rule(_has_auth_node) + + # 由于目前nb2暂不支持or连接rule, 因此将or逻辑放在rule内处理 + @classmethod + def has_level_or_node(cls, level: int, auth_node: str) -> Rule: + """ + :param level: 需要群组权限等级 + :param auth_node: 需要的权限节点 + :return: 群组权限等级大于要求等级或者具备权限节点, 权限节点为deny则拒绝 + """ + async def _has_level_or_node(bot: Bot, event: Event, state: T_State) -> bool: + detail_type = event.dict().get(f'{event.get_type()}_type') + group_id = event.dict().get('group_id') + user_id = event.dict().get('user_id') + self_bot = DBBot(self_qq=int(bot.self_id)) + + # level检查部分 + if detail_type != 'group': level_checker = False + else: + level_res = await DBBotGroup(group_id=group_id, self_bot=self_bot).permission_level() + if level_res.result >= level: + level_checker = True + else: + level_checker = False + + # node检查部分 + if detail_type == 'private': + user_auth = DBAuth(self_bot=self_bot, auth_id=user_id, auth_type='user', auth_node=auth_node) + user_tag_res = await user_auth.tags_info() + allow_tag = user_tag_res.result[0] + deny_tag = user_tag_res.result[1] + elif str(detail_type).startswith('group'): + group_auth = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=auth_node) + group_tag_res = await group_auth.tags_info() + allow_tag = group_tag_res.result[0] + deny_tag = group_tag_res.result[1] + else: + allow_tag = 0 + deny_tag = 0 - # node检查部分 - if detail_type == 'private': - user_auth = DBAuth(auth_id=user_id, auth_type='user', auth_node=auth_node) - user_tag_res = await user_auth.tags_info() - allow_tag = user_tag_res.result[0] - deny_tag = user_tag_res.result[1] - elif detail_type == 'group': - group_auth = DBAuth(auth_id=group_id, auth_type='group', auth_node=auth_node) - group_tag_res = await group_auth.tags_info() - allow_tag = group_tag_res.result[0] - deny_tag = group_tag_res.result[1] - else: - allow_tag = 0 - deny_tag = 0 - - if allow_tag == 1 and deny_tag == 0: - return True - elif allow_tag == -2 and deny_tag == -2: - return level_checker - else: - return False - - return Rule(_has_level_or_node) - - -def has_friend_private_permission() -> Rule: - async def _has_friend_private_permission(bot: Bot, event: Event, state: T_State) -> bool: - detail_type = event.dict().get(f'{event.get_type()}_type') - user_id = event.dict().get('user_id') - # 检查当前消息类型 - if detail_type != 'private': - return False - else: - res = await DBFriend(user_id=user_id).get_private_permission() - if res.result == 1: + if allow_tag == 1 and deny_tag == 0: return True + elif allow_tag == -2 and deny_tag == -2: + return level_checker else: return False - return Rule(_has_friend_private_permission) + return Rule(_has_level_or_node) + + @classmethod + def has_friend_private_permission(cls) -> Rule: + async def _has_friend_private_permission(bot: Bot, event: Event, state: T_State) -> bool: + detail_type = event.dict().get(f'{event.get_type()}_type') + user_id = event.dict().get('user_id') + self_bot = DBBot(self_qq=int(bot.self_id)) + # 检查当前消息类型 + if detail_type != 'private': + return False + else: + res = await DBFriend(user_id=user_id, self_bot=self_bot).get_private_permission() + if res.result == 1: + return True + else: + return False + return Rule(_has_friend_private_permission) __all__ = [ - 'has_notice_permission', - 'has_command_permission', - 'has_auth_node', - 'has_level_or_node', - 'permission_level', - 'has_friend_private_permission' + 'OmegaRules' ] diff --git a/omega_miya/utils/Omega_processor/__init__.py b/omega_miya/utils/Omega_processor/__init__.py new file mode 100644 index 00000000..2ecd6145 --- /dev/null +++ b/omega_miya/utils/Omega_processor/__init__.py @@ -0,0 +1,52 @@ +""" +@Author : Ailitonia +@Date : 2021/07/09 19:49 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : 集合全部 processor 统一处理冷却、权限等 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from typing import Optional +from nonebot.message import event_preprocessor, event_postprocessor, run_preprocessor, run_postprocessor +from nonebot.typing import T_State +from nonebot.matcher import Matcher +from nonebot.adapters.cqhttp.event import Event, MessageEvent +from nonebot.adapters.cqhttp.bot import Bot +from .permission import preprocessor_permission +from .cooldown import preprocessor_cooldown +from .history import postprocessor_history + + +# 事件预处理 +@event_preprocessor +async def handle_event_preprocessor(bot: Bot, event: Event, state: T_State): + pass + + +# 运行预处理 +@run_preprocessor +async def handle_run_preprocessor(matcher: Matcher, bot: Bot, event: Event, state: T_State): + # 处理权限 + if isinstance(event, MessageEvent): + await preprocessor_permission(matcher=matcher, bot=bot, event=event, state=state) + + # 处理冷却 + if isinstance(event, MessageEvent): + await preprocessor_cooldown(matcher=matcher, bot=bot, event=event, state=state) + + +# 运行后处理 +@run_postprocessor +async def handle_run_postprocessor( + matcher: Matcher, exception: Optional[Exception], bot: Bot, event: Event, state: T_State): + # 处理插件统计 + pass + + +# 事件后处理 +@event_postprocessor +async def handle_event_postprocessor(bot: Bot, event: Event, state: T_State): + # 处理历史记录 + await postprocessor_history(bot=bot, event=event, state=state) diff --git a/omega_miya/utils/Omega_CoolDown/__init__.py b/omega_miya/utils/Omega_processor/cooldown.py similarity index 76% rename from omega_miya/utils/Omega_CoolDown/__init__.py rename to omega_miya/utils/Omega_processor/cooldown.py index ebc33b98..3b540fe3 100644 --- a/omega_miya/utils/Omega_CoolDown/__init__.py +++ b/omega_miya/utils/Omega_processor/cooldown.py @@ -7,27 +7,28 @@ from nonebot import get_plugin, get_driver, logger from nonebot.adapters.cqhttp import MessageSegment, Message from nonebot.exception import IgnoredException -from nonebot.message import run_preprocessor from nonebot.typing import T_State from nonebot.matcher import Matcher from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent -from omega_miya.utils.Omega_plugin_utils import \ - check_and_set_global_cool_down, check_and_set_plugin_cool_down, \ - check_and_set_group_cool_down, check_and_set_user_cool_down, PluginCoolDown -from omega_miya.utils.Omega_Base import DBCoolDownEvent, DBAuth +from omega_miya.utils.Omega_plugin_utils import PluginCoolDown +from omega_miya.utils.Omega_Base import DBCoolDownEvent, DBAuth, DBBot -@run_preprocessor -async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent, state: T_State): +global_config = get_driver().config +SUPERUSERS = global_config.superusers + + +async def preprocessor_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent, state: T_State): + """ + 冷却处理 T_RunPreProcessor + """ + group_id = event.dict().get('group_id') user_id = event.dict().get('user_id') - global_config = get_driver().config - superusers = global_config.superusers - # 忽略超级用户 - if user_id in [int(x) for x in superusers]: + if user_id in [int(x) for x in SUPERUSERS]: return # 只处理message事件 @@ -45,14 +46,21 @@ async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent if not plugin_cool_down_list: return + # 跳过由 got 等事件处理函数创建临时 matcher 避免冷却在命令交互中被不正常触发 + if matcher.temp: + return + + # 处理不同bot权限 + self_bot = DBBot(self_qq=int(bot.self_id)) + # 检查用户或群组是否有skip_cd权限, 跳过冷却检查 skip_cd_auth_node = f'{plugin_name}.{PluginCoolDown.skip_auth_node}' - user_auth = DBAuth(auth_id=user_id, auth_type='user', auth_node=skip_cd_auth_node) + user_auth = DBAuth(self_bot=self_bot, auth_id=user_id, auth_type='user', auth_node=skip_cd_auth_node) user_tag_res = await user_auth.tags_info() if user_tag_res.result[0] == 1 and user_tag_res.result[1] == 0: return - group_auth = DBAuth(auth_id=group_id, auth_type='group', auth_node=skip_cd_auth_node) + group_auth = DBAuth(self_bot=self_bot, auth_id=group_id, auth_type='group', auth_node=skip_cd_auth_node) group_tag_res = await group_auth.tags_info() if group_tag_res.result[0] == 1 and group_tag_res.result[1] == 0: return @@ -70,7 +78,7 @@ async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent elif global_check.result == 1: await bot.send(event=event, message=Message(f'{MessageSegment.at(user_id=user_id)}命令冷却中!\n{global_check.info}')) raise IgnoredException('全局命令冷却中') - elif global_check.result == 0: + elif global_check.result in [0, 2]: pass else: logger.error(f'全局冷却事件异常! group: {group_id}, user: {user_id}, error: {global_check.info}') @@ -80,11 +88,11 @@ async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent if plugin_check.result == 1 or group_check.result == 1 or user_check.result == 1: break - res = await check_and_set_global_cool_down(minutes=time) + res = await PluginCoolDown.check_and_set_global_cool_down(minutes=time) if res.result == 1: await bot.send(event=event, message=Message(f'{MessageSegment.at(user_id=user_id)}命令冷却中!\n{res.info}')) raise IgnoredException('全局命令冷却中') - elif res.result == 0: + elif res.result in [0, 2]: pass else: logger.error(f'全局冷却事件异常! group: {group_id}, user: {user_id}, error: {res.info}') @@ -95,11 +103,11 @@ async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent if group_check.result == 1 or user_check.result == 1: break - res = await check_and_set_plugin_cool_down(minutes=time, plugin=plugin_name) + res = await PluginCoolDown.check_and_set_plugin_cool_down(minutes=time, plugin=plugin_name) if res.result == 1: await bot.send(event=event, message=Message(f'{MessageSegment.at(user_id=user_id)}命令冷却中!\n{res.info}')) raise IgnoredException('插件命令冷却中') - elif res.result == 0: + elif res.result in [0, 2]: pass else: logger.error(f'插件冷却事件异常! group: {group_id}, user: {user_id}, plugin: {plugin_name}, error: {res.info}') @@ -113,11 +121,11 @@ async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent if user_check.result == 1: break - res = await check_and_set_group_cool_down(minutes=time, plugin=plugin_name, group_id=group_id) + res = await PluginCoolDown.check_and_set_group_cool_down(minutes=time, plugin=plugin_name, group_id=group_id) if res.result == 1: await bot.send(event=event, message=Message(f'{MessageSegment.at(user_id=user_id)}命令冷却中!\n{res.info}')) raise IgnoredException('群组命令冷却中') - elif res.result == 0: + elif res.result in [0, 2]: pass else: logger.error(f'群组冷却事件异常! group: {group_id}, user: {user_id}, plugin: {plugin_name}, error: {res.info}') @@ -127,11 +135,16 @@ async def handle_plugin_cooldown(matcher: Matcher, bot: Bot, event: MessageEvent if not user_id: break - res = await check_and_set_user_cool_down(minutes=time, plugin=plugin_name, user_id=user_id) + res = await PluginCoolDown.check_and_set_user_cool_down(minutes=time, plugin=plugin_name, user_id=user_id) if res.result == 1: await bot.send(event=event, message=Message(f'{MessageSegment.at(user_id=user_id)}命令冷却中!\n{res.info}')) raise IgnoredException('用户命令冷却中') - elif res.result == 0: + elif res.result in [0, 2]: pass else: logger.error(f'用户冷却事件异常! group: {group_id}, user: {user_id}, plugin: {plugin_name}, error: {res.info}') + + +__all__ = [ + 'preprocessor_cooldown' +] diff --git a/omega_miya/utils/Omega_processor/history.py b/omega_miya/utils/Omega_processor/history.py new file mode 100644 index 00000000..4cefdf1c --- /dev/null +++ b/omega_miya/utils/Omega_processor/history.py @@ -0,0 +1,81 @@ +""" +@Author : Ailitonia +@Date : 2021/07/29 4:14 +@FileName : history.py +@Project : nonebot2_miya +@Description : 历史记录模块 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from nonebot import logger +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import (Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent, + NoticeEvent, RequestEvent, MetaEvent) +from omega_miya.utils.Omega_Base import DBHistory + + +async def postprocessor_history(bot: Bot, event: Event, state: T_State): + try: + time = event.time + self_id = event.self_id + post_type = event.post_type + if isinstance(event, MetaEvent): + # 不记录元事件 + return + elif isinstance(event, MessageEvent): + detail_type = event.message_type + sub_type = event.sub_type + event_id = event.message_id + if isinstance(event, GroupMessageEvent): + group_id = event.group_id + elif isinstance(event, PrivateMessageEvent): + group_id = 0 + else: + group_id = -1 + user_id = event.user_id + user_name = f'{event.sender.nickname}/{event.sender.card}' + raw_data = repr(event) + msg_data = str(event.message) + elif isinstance(event, NoticeEvent): + detail_type = event.notice_type + sub_type = event.dict().get('sub_type', 'Undefined') + event_id = -1 + group_id = event.dict().get('group_id', -1) + user_id = event.dict().get('user_id', -1) + user_name = '' + raw_data = repr(event) + msg_data = '' + elif isinstance(event, RequestEvent): + detail_type = event.request_type + sub_type = event.dict().get('sub_type', 'Undefined') + event_id = -1 + group_id = event.dict().get('group_id', -1) + user_id = event.dict().get('user_id', -1) + user_name = '' + raw_data = repr(event) + msg_data = '' + else: + detail_type = event.get_event_name() + sub_type = event.dict().get('sub_type', 'Undefined') + event_id = -1 + group_id = event.dict().get('group_id', -1) + user_id = event.dict().get('user_id', -1) + user_name = '' + raw_data = repr(event) + msg_data = str(event.dict().get('message')) + + new_history = DBHistory(time=time, self_id=self_id, post_type=post_type, detail_type=detail_type) + add_result = await new_history.add( + sub_type=sub_type, event_id=event_id, group_id=group_id, user_id=user_id, user_name=user_name, + raw_data=raw_data, msg_data=msg_data) + if add_result.error: + logger.error(f'History recording failed with database error: {add_result.info}, event: {repr(event)}') + except Exception as e: + logger.error(f'History recording Failed, error: {repr(e)}, event: {repr(event)}') + + +__all__ = [ + 'postprocessor_history' +] diff --git a/omega_miya/utils/Omega_Permission/__init__.py b/omega_miya/utils/Omega_processor/permission.py similarity index 63% rename from omega_miya/utils/Omega_Permission/__init__.py rename to omega_miya/utils/Omega_processor/permission.py index 48c9c267..4045d13a 100644 --- a/omega_miya/utils/Omega_Permission/__init__.py +++ b/omega_miya/utils/Omega_processor/permission.py @@ -1,16 +1,22 @@ -from nonebot import get_driver, logger +from nonebot import get_driver from nonebot.exception import IgnoredException -from nonebot.message import run_preprocessor from nonebot.typing import T_State from nonebot.matcher import Matcher from nonebot.adapters.cqhttp.bot import Bot from nonebot.adapters.cqhttp.event import MessageEvent, GroupMessageEvent, PrivateMessageEvent -from omega_miya.utils.Omega_plugin_utils import \ - check_command_permission, check_permission_level, check_auth_node, check_friend_private_permission +from omega_miya.utils.Omega_plugin_utils import PermissionChecker +from omega_miya.utils.Omega_Base import DBBot -@run_preprocessor -async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEvent, state: T_State): +global_config = get_driver().config +SUPERUSERS = global_config.superusers + + +async def preprocessor_permission(matcher: Matcher, bot: Bot, event: MessageEvent, state: T_State): + """ + 权限处理 T_RunPreProcessor + """ + if isinstance(event, PrivateMessageEvent): private_mode = True elif isinstance(event, GroupMessageEvent): @@ -21,11 +27,8 @@ async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEve group_id = event.dict().get('group_id') user_id = event.dict().get('user_id') - global_config = get_driver().config - superusers = global_config.superusers - # 忽略超级用户 - if user_id in [int(x) for x in superusers]: + if user_id in [int(x) for x in SUPERUSERS]: return matcher_default_state = matcher.state @@ -33,10 +36,14 @@ async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEve matcher_permission_level = matcher_default_state.get('_permission_level') matcher_auth_node = matcher_default_state.get('_auth_node') + # 处理不同bot权限 + self_bot = DBBot(self_qq=int(bot.self_id)) + permission_checker = PermissionChecker(self_bot=self_bot) + # 检查command/friend_private权限 if private_mode: if matcher_command_permission: - command_checker = await check_friend_private_permission(user_id=user_id) + command_checker = await permission_checker.check_friend_private_permission(user_id=user_id) if command_checker: pass else: @@ -44,7 +51,7 @@ async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEve raise IgnoredException('没有好友命令权限') else: if matcher_command_permission: - command_checker = await check_command_permission(group_id=group_id) + command_checker = await permission_checker.check_command_permission(group_id=group_id) if command_checker: pass else: @@ -56,7 +63,8 @@ async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEve if private_mode: level_checker = True else: - level_checker = await check_permission_level(group_id=group_id, level=matcher_permission_level) + level_checker = await permission_checker.check_permission_level(group_id=group_id, + level=matcher_permission_level) else: level_checker = False @@ -64,11 +72,17 @@ async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEve if matcher_auth_node: auth_node = '.'.join([matcher.module, matcher_auth_node]) # 分别检查用户及群组权限节点 - user_auth_checker = await check_auth_node(auth_id=user_id, auth_type='user', auth_node=auth_node) + user_auth_checker = await permission_checker.check_auth_node(auth_id=user_id, + auth_type='user', + auth_node=auth_node) + if private_mode: group_auth_checker = 0 else: - group_auth_checker = await check_auth_node(auth_id=group_id, auth_type='group', auth_node=auth_node) + group_auth_checker = await permission_checker.check_auth_node(auth_id=group_id, + auth_type='group', + auth_node=auth_node) + # 优先级: 用户权限节点>群组权限节点>权限等级 if user_auth_checker == -1 or group_auth_checker == -1: await bot.send(event=event, message=f'权限受限QAQ') @@ -87,3 +101,8 @@ async def handle_plugin_permission(matcher: Matcher, bot: Bot, event: MessageEve elif matcher_permission_level and not level_checker: await bot.send(event=event, message=f'群组权限等级不足QAQ') raise IgnoredException('群组权限等级不足') + + +__all__ = [ + 'preprocessor_permission' +] diff --git a/omega_miya/utils/bilibili_utils/dynamic.py b/omega_miya/utils/bilibili_utils/dynamic.py index e4621d03..0374fd66 100644 --- a/omega_miya/utils/bilibili_utils/dynamic.py +++ b/omega_miya/utils/bilibili_utils/dynamic.py @@ -107,7 +107,7 @@ def data_parser(cls, dynamic_data: dict) -> BiliResult.DynamicInfoResult: # type=1, 这是一条转发的动态 if type_ == 1: origin_user = dynamic_card.get('origin_user') - if origin_user: + if origin_user and origin_user['info'].get('uname'): origin_user_name = origin_user['info'].get('uname') desc = f'转发了{origin_user_name}的动态' else: @@ -175,6 +175,13 @@ def data_parser(cls, dynamic_data: dict) -> BiliResult.DynamicInfoResult: content = dynamic_card['vest']['content'] title = dynamic_card['sketch']['title'] description = dynamic_card['sketch']['desc_text'] + # type=4200, 直播间动态(疑似) + elif type_ == 4200: + desc = '发布了一条直播间动态' + content = f"{dynamic_card['uname']}的直播间 - {dynamic_card['title']}" + pictures.append(dynamic_card['cover']) + title = dynamic_card['title'] + description = None # 其他未知类型 else: desc = 'Unknown' diff --git a/omega_miya/utils/bilibili_utils/request_utils.py b/omega_miya/utils/bilibili_utils/request_utils.py index 6bdca258..da74a0b9 100644 --- a/omega_miya/utils/bilibili_utils/request_utils.py +++ b/omega_miya/utils/bilibili_utils/request_utils.py @@ -75,7 +75,7 @@ async def verify_cookies(self) -> Result.TextResult: @classmethod # 图片转base64 - async def pic_2_base64(cls, url: str) -> Result.TextResult: + async def pic_to_base64(cls, url: str) -> Result.TextResult: headers = {'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' 'Chrome/89.0.4389.114 Safari/537.36', 'origin': 'https://www.bilibili.com', @@ -94,6 +94,26 @@ async def pic_2_base64(cls, url: str) -> Result.TextResult: else: return Result.TextResult(error=True, info=encode_result.info, result='') + @classmethod + async def pic_to_file(cls, url: str) -> Result.TextResult: + headers = {'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' + 'Chrome/89.0.4389.114 Safari/537.36', + 'origin': 'https://www.bilibili.com', + 'referer': 'https://www.bilibili.com/'} + + fetcher = HttpFetcher( + timeout=30, attempt_limit=2, flag='bilibili_live_monitor_get_image', headers=headers) + bytes_result = await fetcher.get_bytes(url=url) + if bytes_result.error: + return Result.TextResult(error=True, info='Image download failed', result='') + + encode_result = await PicEncoder.bytes_to_file(image=bytes_result.result, folder_flag='bilibili') + + if encode_result.success(): + return Result.TextResult(error=False, info='Success', result=encode_result.result) + else: + return Result.TextResult(error=True, info=encode_result.info, result='') + __all__ = [ 'BiliRequestUtils' diff --git a/omega_miya/utils/bilibili_utils/user.py b/omega_miya/utils/bilibili_utils/user.py index 0ce03687..22a51ccb 100644 --- a/omega_miya/utils/bilibili_utils/user.py +++ b/omega_miya/utils/bilibili_utils/user.py @@ -80,6 +80,8 @@ async def get_dynamic_history(self) -> Result.DictListResult: return Result.DictListResult(error=True, info=result.result.get('message'), result=[]) try: + if not result.result['data'].get('cards'): + return Result.DictListResult(error=False, info='Success. But user has no dynamic.', result=[]) data_list = [dict(card) for card in result.result['data']['cards']] return Result.DictListResult(error=False, info='Success', result=data_list) except Exception as e: diff --git a/omega_miya/utils/dice_utils/__init__.py b/omega_miya/utils/dice_utils/__init__.py new file mode 100644 index 00000000..f85d1791 --- /dev/null +++ b/omega_miya/utils/dice_utils/__init__.py @@ -0,0 +1,18 @@ +""" +@Author : Ailitonia +@Date : 2021/07/18 1:24 +@FileName : __init__.py.py +@Project : nonebot2_miya +@Description : 掷骰及计算工具包 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from .calculator import BaseCalculator +from .dice import BaseDice + + +__all__ = [ + 'BaseCalculator', + 'BaseDice' +] diff --git a/omega_miya/utils/dice_utils/calculator.py b/omega_miya/utils/dice_utils/calculator.py new file mode 100644 index 00000000..6b254968 --- /dev/null +++ b/omega_miya/utils/dice_utils/calculator.py @@ -0,0 +1,179 @@ +""" +@Author : Ailitonia +@Date : 2021/07/18 13:02 +@FileName : calculator.py +@Project : nonebot2_miya +@Description : 表达式计算模块 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import asyncio +import re +from typing import List, Union +from .exception import CalculateException + + +class BaseCalculator(object): + # 限制单步运算算数上限, 避免进行大数运算 + __power_limit: float = 65536 + __multi_limit: float = 4294967296 + __add_limit: float = 4294967296 + + def __init__(self, expression: str): + self.__raw_expression: str = expression + # 移除空白字符 + __expression = re.sub(r'\s', '', expression) + # 替换中文符号 + __expression = re.sub(r'(', '(', __expression) + __expression = re.sub(r')', ')', __expression) + # 替换运算符 + __expression = re.sub(r'[xX×]', '*', __expression) + __expression = re.sub(r'[÷]', '/', __expression) + self.__expression: str = __expression + + @classmethod + def __handle_sequence_calculate( + cls, calculate_sequence: List[Union[str, int, float]]) -> List[Union[str, int, float]]: + """ + 对拆分后的基本算式执行分步运算 + :param calculate_sequence: 拆分后的基本算式 + :return: calculate_sequence, List[Union[str, int, float]] + """ + # 处理运算 + if '^' in calculate_sequence: + # 处理乘方 + for _index, _obj in reversed(list(enumerate(calculate_sequence))): + # 乘方需要从后向前匹配, 发现乘方则直接进行运算并返回 + if _obj == '^': + num_back = float(calculate_sequence.pop(_index + 1)) + # 移除运算符 + calculate_sequence.pop(_index) + num_front = float(calculate_sequence.pop(_index - 1)) + if num_front >= cls.__power_limit or num_back >= cls.__power_limit: + raise CalculateException('单步运算算数大小超过限制上限', f'{num_front}, {num_back}') + result = num_front ** num_back + calculate_sequence.insert(_index - 1, result) + return calculate_sequence + elif '*' in calculate_sequence or '/' in calculate_sequence: + # 处理乘除法 + for _index, _obj in enumerate(calculate_sequence): + # 从前向后匹配, 发现乘除法则直接进行运算并返回 + if _obj == '*': + num_back = float(calculate_sequence.pop(_index + 1)) + # 移除运算符 + calculate_sequence.pop(_index) + num_front = float(calculate_sequence.pop(_index - 1)) + if num_front >= cls.__multi_limit or num_back >= cls.__multi_limit: + raise CalculateException('单步运算算数大小超过限制上限', f'{num_front}, {num_back}') + result = num_front * num_back + calculate_sequence.insert(_index - 1, result) + return calculate_sequence + elif _obj == '/': + num_back = float(calculate_sequence.pop(_index + 1)) + # 移除运算符 + calculate_sequence.pop(_index) + num_front = float(calculate_sequence.pop(_index - 1)) + if num_front >= cls.__multi_limit or num_back >= cls.__multi_limit: + raise CalculateException('单步运算算数大小超过限制上限', f'{num_front}, {num_back}') + result = num_front / num_back + calculate_sequence.insert(_index - 1, result) + return calculate_sequence + elif '+' in calculate_sequence or '-' in calculate_sequence: + # 处理加减法 + for _index, _obj in enumerate(calculate_sequence): + # 从前向后匹配, 发现加减法则直接进行运算并返回 + if _obj == '+': + num_back = float(calculate_sequence.pop(_index + 1)) + # 移除运算符 + calculate_sequence.pop(_index) + num_front = float(calculate_sequence.pop(_index - 1)) + if num_front >= cls.__add_limit or num_back >= cls.__add_limit: + raise CalculateException('单步运算算数大小超过限制上限', f'{num_front}, {num_back}') + result = num_front + num_back + calculate_sequence.insert(_index - 1, result) + return calculate_sequence + elif _obj == '-': + num_back = float(calculate_sequence.pop(_index + 1)) + # 移除运算符 + calculate_sequence.pop(_index) + num_front = float(calculate_sequence.pop(_index - 1)) + if num_front >= cls.__add_limit or num_back >= cls.__add_limit: + raise CalculateException('单步运算算数大小超过限制上限', f'{num_front}, {num_back}') + result = num_front - num_back + calculate_sequence.insert(_index - 1, result) + return calculate_sequence + else: + # 运算序列中已经没有运算符, 直接返回结果 + if len(calculate_sequence) == 1: + return calculate_sequence + else: + raise CalculateException('执行分步运算错误, 非预期的结果', calculate_sequence) + + @classmethod + def __base_calculate(cls, expression: str) -> float: + """ + 解析并运算不含括号的四则算式 + :param expression: 整数四则运算算式 + :return: float, 运算结果 + """ + # 移除空白字符 + expression = re.sub(r'\s', '', expression) + # 移除首尾括号 + expression = expression.lstrip('(').rstrip(')') + + # 判断运算符合法 + if re.search(r'[^+\-*/^.\d]', expression): + raise CalculateException('非法算式, 包含运算符之外的字符', expression) + + # 拆分所有运算符 + cal_seq: List[Union[str, int, float]] = [x for x in re.split(r'([+\-*/^])', expression) if x] + if re.match(r'[+*/^]', cal_seq[0]) or re.match(r'[+\-*/^]', cal_seq[-1]): + raise CalculateException('非法算式, 算式首尾出现运算符', expression) + + # 处理负数 + for _index, _obj in enumerate(cal_seq): + # 负号在第一位 + if _obj == '-' and _index == 0 and re.match(r'^-?\d+?(\.\d+?)?$', str(cal_seq[_index + 1])): + num = float(cal_seq.pop(_index + 1)) + cal_seq.pop(_index) + cal_seq.insert(_index, -num) + # 负号在中间 + elif _obj == '-' and re.match(r'^[+\-*/^]$', str(cal_seq[_index - 1])) and re.match( + r'^-?\d+?(\.\d+?)?$', str(cal_seq[_index + 1])): + num = float(cal_seq.pop(_index + 1)) + cal_seq.pop(_index) + cal_seq.insert(_index, -num) + + # 负数处理完后再次判断运算符合法 + for _index, _obj in enumerate(cal_seq): + if re.match(r'^[+\-*/^]$', str(_obj)) and re.match(r'^[+\-*/^]$', str(cal_seq[_index - 1])): + raise CalculateException('非法算式, 包含连续的运算符', expression) + + # 分步循环执行运算 + while len(cal_seq) != 1: + cal_seq = cls.__handle_sequence_calculate(cal_seq) + return float(cal_seq[0]) + + def std_calculate(self) -> float: + """ + 解析并运算四则算式 + :return: float, 运算结果 + """ + _expression = self.__expression + # 从最内层括号开始依次执行运算 + while inner_par := re.search(r'(\([+\-*/^.\d]+?\))', _expression): + inner_cal_result = self.__base_calculate(inner_par.group()) + _expression = f'{_expression[:inner_par.start()]}{inner_cal_result}{_expression[inner_par.end():]}' + # 括号遍历完后执行最外层运算 + return self.__base_calculate(_expression) + + async def aio_std_calculate(self) -> float: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, self.std_calculate) + return result + + +__all__ = [ + 'BaseCalculator' +] diff --git a/omega_miya/utils/dice_utils/dice.py b/omega_miya/utils/dice_utils/dice.py new file mode 100644 index 00000000..8cc69c1d --- /dev/null +++ b/omega_miya/utils/dice_utils/dice.py @@ -0,0 +1,76 @@ +""" +@Author : Ailitonia +@Date : 2021/07/18 1:28 +@FileName : dice.py +@Project : nonebot2_miya +@Description : 掷骰核心 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import re +import random +import asyncio +from typing import List, Dict + + +class BaseDice(object): + """ + 骰子基类, 模拟真实骰子 + """ + def __init__(self, num: int = 1, side: int = 100): + self.__num: int = num # 骰子个数 + self.__side: int = side # 骰子面数 + self.__result: int = -1 # 掷骰结果 + self.__index: int = -1 # 掷骰子次数索引 + self.__all_result: Dict[int, List[int]] = {} # 记录全部掷骰历史 + + @property + def result(self) -> int: # 本次掷骰点数总和 + return self.__result + + @property + def index(self) -> int: + return self.__index + + @property + def count(self) -> int: + return self.__index + 1 + + @property + def full_result(self) -> List[int]: # 本次掷骰具体详情 + if self.count > 0: + return self.__all_result[self.index] + else: + return [] + + @property + def all_result(self) -> Dict[int, List[int]]: # 历史所有掷骰结果详情 + return self.__all_result + + def dice(self) -> int: + """ + 标准掷骰 + :return: int, 本次掷骰点数总和 + """ + self.__index += 1 + result: List[int] = [] + + for i in range(self.__num): + this_dice_result = random.choice(range(self.__side)) + 1 + result.append(this_dice_result) + + self.__result = sum(result) + self.__all_result.update({self.__index: result}) + + return self.result + + async def aio_dice(self) -> int: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, self.dice) + return result + + +__all__ = [ + 'BaseDice' +] diff --git a/omega_miya/utils/dice_utils/exception.py b/omega_miya/utils/dice_utils/exception.py new file mode 100644 index 00000000..630c38b3 --- /dev/null +++ b/omega_miya/utils/dice_utils/exception.py @@ -0,0 +1,28 @@ +""" +@Author : Ailitonia +@Date : 2021/07/18 1:36 +@FileName : exception.py +@Project : nonebot2_miya +@Description : +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + + +class DiceBaseException(Exception): + pass + + +class CalculateException(DiceBaseException): + """ + 计算模块异常 + """ + def __init__(self, reason, expression): + self.reason = reason + self.expression = expression + + def __repr__(self): + return f'' + + def __str__(self): + return self.__repr__() diff --git a/omega_miya/utils/pixiv_utils/__init__.py b/omega_miya/utils/pixiv_utils/__init__.py index c7013a4d..f4ae10ca 100644 --- a/omega_miya/utils/pixiv_utils/__init__.py +++ b/omega_miya/utils/pixiv_utils/__init__.py @@ -1,10 +1,11 @@ -from .pixiv import Pixiv, PixivIllust +from .pixiv import Pixiv, PixivIllust, PixivUser from .pixivision import Pixivision, PixivisionArticle __all__ = [ 'Pixiv', 'PixivIllust', + 'PixivUser', 'Pixivision', 'PixivisionArticle' ] diff --git a/omega_miya/utils/pixiv_utils/pixiv.py b/omega_miya/utils/pixiv_utils/pixiv.py index 1dee5b78..589f73ff 100644 --- a/omega_miya/utils/pixiv_utils/pixiv.py +++ b/omega_miya/utils/pixiv_utils/pixiv.py @@ -1,8 +1,13 @@ import re import os +import pathlib import json import asyncio import aiofiles +import zipfile +import imageio +from io import BytesIO +from typing import Dict, Optional from nonebot import logger, get_driver from omega_miya.utils.Omega_plugin_utils import HttpFetcher, PicEncoder, create_zip_file from omega_miya.utils.Omega_Base import Result @@ -37,95 +42,67 @@ class Pixiv(object): 'Chrome/89.0.4389.114 Safari/537.36'} @classmethod - async def daily_ranking(cls) -> Result.DictResult: - payload_daily = {'format': 'json', 'mode': 'daily', - 'content': 'illust', 'p': 1} - fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_daily_ranking', headers=cls.HEADERS) - daily_ranking_result = await fetcher.get_json(url=cls.RANKING_URL, params=payload_daily) - if daily_ranking_result.error: - return Result.DictResult( - error=True, info=f'Fetch daily ranking failed, {daily_ranking_result.info}', result={}) - - daily_ranking_data = daily_ranking_result.result.get('contents') - if type(daily_ranking_data) != list: - return Result.DictResult( - error=True, info=f'Daily ranking data error, {daily_ranking_result.result}', result={}) - - result = {} - for num in range(len(daily_ranking_data)): - try: - illust_id = daily_ranking_data[num].get('illust_id') - illust_title = daily_ranking_data[num].get('title') - illust_uname = daily_ranking_data[num].get('user_name') - result.update({num: { - 'illust_id': illust_id, - 'illust_title': illust_title, - 'illust_uname': illust_uname - }}) - except Exception as e: - logger.debug(f'Pixiv | Daily ranking data error at {num}, ignored. {str(e)},') - continue - return Result.DictResult(error=False, info='Success', result=result) - - @classmethod - async def weekly_ranking(cls) -> Result.DictResult: - payload_weekly = {'format': 'json', 'mode': 'weekly', - 'content': 'illust', 'p': 1} - fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_weekly_ranking', headers=cls.HEADERS) - weekly_ranking_result = await fetcher.get_json(url=cls.RANKING_URL, params=payload_weekly) - if weekly_ranking_result.error: - return Result.DictResult( - error=True, info=f'Fetch weekly ranking failed, {weekly_ranking_result.info}', result={}) - - weekly_ranking_data = weekly_ranking_result.result.get('contents') - if type(weekly_ranking_data) != list: - return Result.DictResult( - error=True, info=f'Weekly ranking data error, {weekly_ranking_result.result}', result={}) - - result = {} - for num in range(len(weekly_ranking_data)): - try: - illust_id = weekly_ranking_data[num].get('illust_id') - illust_title = weekly_ranking_data[num].get('title') - illust_uname = weekly_ranking_data[num].get('user_name') - result.update({num: { - 'illust_id': illust_id, - 'illust_title': illust_title, - 'illust_uname': illust_uname - }}) - except Exception as e: - logger.debug(f'Pixiv | Weekly ranking data error at {num}, ignored. {str(e)},') - continue - return Result.DictResult(error=False, info='Success', result=result) + def parse_pid_from_url(cls, text: str, *, url_mode: bool = False) -> Optional[int]: + if url_mode: + # 分别匹配不同格式pivix链接格式 仅能匹配特定 url 格式的字符串 + if url_new := re.search(r'^https?://.*?pixiv\.net/(artworks|i)/(\d+?)$', text): + return int(url_new.groups()[1]) + elif url_old := re.search(r'^https?://.*?pixiv\.net.*?illust_id=(\d+?)(&mode=\w+?)?$', text): + return int(url_old.groups()[0]) + else: + return None + else: + # 分别匹配不同格式pivix链接格式 可匹配任何字符串中的url + if url_new := re.search(r'https?://.*?pixiv\.net/(artworks|i)/(\d+)', text): + return int(url_new.groups()[1]) + elif url_old := re.search(r'https?://.*?pixiv\.net.*?illust_id=(\d+)', text): + return int(url_old.groups()[0]) + else: + return None @classmethod - async def monthly_ranking(cls) -> Result.DictResult: - payload_monthly = {'format': 'json', 'mode': 'monthly', - 'content': 'illust', 'p': 1} - fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_monthly_ranking', headers=cls.HEADERS) - monthly_ranking_result = await fetcher.get_json(url=cls.RANKING_URL, params=payload_monthly) - if monthly_ranking_result.error: + async def get_ranking( + cls, + mode: str, + *, + page: int = 1, + content: Optional[str] = 'illust' + ) -> Result.DictResult: + """ + 获取 Pixiv 排行榜 + :param mode: 排行榜类型 + :param page: 页数 + :param content: 作品类型 + :return: + """ + if not content: + payload = {'format': 'json', 'mode': mode, 'p': page} + else: + payload = {'format': 'json', 'mode': mode, 'content': content, 'p': page} + fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_get_ranking', headers=cls.HEADERS) + ranking_result = await fetcher.get_json(url=cls.RANKING_URL, params=payload) + if ranking_result.error: return Result.DictResult( - error=True, info=f'Fetch monthly ranking failed, {monthly_ranking_result.info}', result={}) + error=True, info=f'Fetch ranking result failed, {ranking_result.info}', result={}) - monthly_ranking_data = monthly_ranking_result.result.get('contents') - if type(monthly_ranking_data) != list: + ranking_data = ranking_result.result.get('contents') + if type(ranking_data) != list: return Result.DictResult( - error=True, info=f'Monthly ranking data error, {monthly_ranking_result.result}', result={}) + error=True, info=f'Getting ranking data error, {ranking_result.result}', result={}) result = {} - for num in range(len(monthly_ranking_data)): + for num in range(len(ranking_data)): try: - illust_id = monthly_ranking_data[num].get('illust_id') - illust_title = monthly_ranking_data[num].get('title') - illust_uname = monthly_ranking_data[num].get('user_name') + illust_id = ranking_data[num].get('illust_id') + illust_title = ranking_data[num].get('title') + illust_uname = ranking_data[num].get('user_name') result.update({num: { 'illust_id': illust_id, 'illust_title': illust_title, 'illust_uname': illust_uname }}) except Exception as e: - logger.debug(f'Pixiv | Monthly ranking data error at {num}, ignored. {repr(e)},') + logger.debug(f'Pixiv | Getting ranking data error at {num}, ignored. {repr(e)},') continue return Result.DictResult(error=False, info='Success', result=result) @@ -133,12 +110,19 @@ async def monthly_ranking(cls) -> Result.DictResult: class PixivIllust(Pixiv): def __init__(self, pid: int): self.__pid: int = pid - self.__is_loaded: bool = False + self.__is_data_loaded: bool = False + self.__is_pic_loaded: bool = False + self.__is_downloaded: bool = False self.__illust_data: dict = {} + self.__pic: bytes = b'' + self.__downloaded_file_path: str = '' # 获取作品完整信息(pixiv api 获取 json) # 返回格式化后的作品信息 async def get_illust_data(self) -> Result.DictResult: + if self.__is_data_loaded: + return Result.DictResult(error=False, info='Success', result=self.__illust_data) + illust_url = f'{self.ILLUST_DATA_URL}{self.__pid}' illust_artworks_url = f'{self.ILLUST_ARTWORK_URL}{self.__pid}' @@ -174,6 +158,8 @@ async def get_illust_data(self) -> Result.DictResult: userid = int(illust_data['body']['userId']) username = str(illust_data['body']['userName']) url = f'{self.ILLUST_ARTWORK_URL}{self.__pid}' + width = int(illust_data['body']['width']) + height = int(illust_data['body']['height']) page_count = int(illust_data['body']['pageCount']) illust_orig_url = str(illust_data['body']['urls']['original']) illust_regular_url = str(illust_data['body']['urls']['regular']) @@ -182,6 +168,11 @@ async def get_illust_data(self) -> Result.DictResult: re_std_description_s2 = r'<[^>]+>' illust_description = re.sub(re_std_description_s1, '\n', illust_description) illust_description = re.sub(re_std_description_s2, '', illust_description) + # 作品相关统计信息 + like_count = int(illust_data['body']['likeCount']) + bookmark_count = int(illust_data['body']['bookmarkCount']) + view_count = int(illust_data['body']['viewCount']) + comment_count = int(illust_data['body']['commentCount']) # 处理作品tag illusttag = [] @@ -197,6 +188,8 @@ async def get_illust_data(self) -> Result.DictResult: continue if 'R-18' in illusttag: is_r18 = True + elif 'R-18G' in illusttag: + is_r18 = True else: is_r18 = False @@ -207,7 +200,10 @@ async def get_illust_data(self) -> Result.DictResult: 'regular': [], 'original': [], } + # PixivPage数据库用, 图片列表原始数据 + origin_pages = {} if not illust_pages.get('error') and illust_pages: + origin_pages.update(dict(enumerate([x.get('urls') for x in illust_pages.get('body')]))) for item in illust_pages.get('body'): all_url.get('thumb_mini').append(item['urls']['thumb_mini']) all_url.get('small').append(item['urls']['small']) @@ -241,10 +237,17 @@ async def get_illust_data(self) -> Result.DictResult: 'uid': userid, 'uname': username, 'url': url, + 'width': width, + 'height': height, + 'like_count': like_count, + 'bookmark_count': bookmark_count, + 'view_count': view_count, + 'comment_count': comment_count, 'page_count': page_count, 'orig_url': illust_orig_url, 'regular_url': illust_regular_url, 'all_url': all_url, + 'illust_pages': origin_pages, 'ugoira_meta': ugoira_meta, 'description': illust_description, 'tags': illusttag, @@ -252,7 +255,7 @@ async def get_illust_data(self) -> Result.DictResult: } # 保存对象状态便于其他方法调用 - self.__is_loaded = True + self.__is_data_loaded = True self.__illust_data.update(result) return Result.DictResult(error=False, info='Success', result=result) @@ -260,9 +263,8 @@ async def get_illust_data(self) -> Result.DictResult: logger.error(f'PixivIllust | Parse illust data failed, error: {repr(e)}') return Result.DictResult(error=True, info=f'Parse illust data failed', result={}) - # 图片转base64 - async def pic_2_base64(self, original: bool = False) -> Result.TextResult: - if self.__is_loaded: + async def get_format_info_msg(self, desc_len: int = 32) -> Result.TextResult: + if self.__is_data_loaded: illust_data = self.__illust_data else: illust_data_result = await self.get_illust_data() @@ -281,7 +283,21 @@ async def pic_2_base64(self, original: bool = False) -> Result.TextResult: if not description: info = f'「{title}」/「{author}」\n{tags}\n{url}' else: - info = f'「{title}」/「{author}」\n{tags}\n{url}\n----------------\n{description[:28]}......' + info = f'「{title}」/「{author}」\n{tags}\n{url}\n----------------\n{description[:desc_len]}......' + return Result.TextResult(error=False, info='Success', result=info) + + # 加载作品图片 + async def load_illust_pic(self, original: bool = False) -> Result.BytesResult: + if self.__is_pic_loaded: + return Result.BytesResult(error=False, info='Success', result=self.__pic) + + if self.__is_data_loaded: + illust_data = self.__illust_data + else: + illust_data_result = await self.get_illust_data() + if illust_data_result.error: + return Result.BytesResult(error=True, info='Fetch illust data failed', result=b'') + illust_data = dict(illust_data_result.result) if original: url = illust_data.get('orig_url') @@ -295,18 +311,168 @@ async def pic_2_base64(self, original: bool = False) -> Result.TextResult: 'sec-fetch-site': 'cross-site' }) - fetcher = HttpFetcher(timeout=30, attempt_limit=2, flag='pixiv_utils_get_image', headers=headers) + fetcher = HttpFetcher(timeout=30, attempt_limit=2, flag='pixiv_utils_load_illust_pic', headers=headers) bytes_result = await fetcher.get_bytes(url=url) if bytes_result.error: - return Result.TextResult(error=True, info='Image download failed', result='') + return Result.BytesResult(error=True, info='Image download failed', result=b'') + else: + # 保存对象状态便于其他方法调用 + self.__is_pic_loaded = True + self.__pic = bytes_result.result + return Result.BytesResult(error=False, info='Success', result=bytes_result.result) - encode_result = PicEncoder.bytes_to_b64(image=bytes_result.result) + # 图片转base64 + async def get_base64(self, original: bool = False) -> Result.TextResult: + if self.__is_pic_loaded: + illust_pic = self.__pic + else: + illust_pic_result = await self.load_illust_pic(original=original) + if illust_pic_result.error: + return Result.TextResult(error=True, info='Image download failed', result='') + illust_pic = illust_pic_result.result + encode_result = PicEncoder.bytes_to_b64(image=illust_pic) if encode_result.success(): - return Result.TextResult(error=False, info=info, result=encode_result.result) + return Result.TextResult(error=False, info='Success', result=encode_result.result) else: return Result.TextResult(error=True, info=encode_result.info, result='') + # 图片转fileurl + async def get_file(self, original: bool = False) -> Result.TextResult: + folder_path = os.path.abspath(os.path.join(TMP_PATH, 'pixiv_illust')) + file_name = f'{self.__pid}_original_{original}' + file_path = os.path.abspath(os.path.join(folder_path, file_name)) + # 如果已经存在则直接返回 + if os.path.exists(file_path): + file_url = pathlib.Path(file_path).as_uri() + return Result.TextResult(error=False, info='Success', result=file_url) + + # 没有的话再下载并保存文件 + if self.__is_pic_loaded: + illust_pic = self.__pic + else: + illust_pic_result = await self.load_illust_pic(original=original) + if illust_pic_result.error: + return Result.TextResult(error=True, info='Image download failed', result='') + illust_pic = illust_pic_result.result + # 检查保存文件路径 + if not os.path.exists(folder_path): + os.makedirs(folder_path) + try: + async with aiofiles.open(file_path, 'wb') as aio_f: + await aio_f.write(illust_pic) + file_url = pathlib.Path(file_path).as_uri() + return Result.TextResult(error=False, info='Success', result=file_url) + except Exception as e: + return Result.TextResult(error=True, info=repr(e), result='') + + async def get_sending_msg(self, *, mode: str = 'file') -> Result.TextTupleResult: + """ + :param mode: 发送图片方式 + file: 下载为本地文件发送 + base64: 使用base64发送 + :return: Tuple[image_url: str, info_msg: str] + """ + if mode == 'file': + img_result = await self.get_file() + elif mode == 'base64': + img_result = await self.get_base64() + else: + return Result.TextTupleResult(error=True, info='Illegal mode', result=()) + + if img_result.error: + return Result.TextTupleResult( + error=True, info=f'Getting img failed, error: {img_result.info}', result=()) + + info_msg_result = await self.get_format_info_msg() + if info_msg_result.error: + return Result.TextTupleResult( + error=True, info=f'Getting info msg failed, error: {info_msg_result.info}', result=()) + + return Result.TextTupleResult(error=False, info='Success', result=(img_result.result, info_msg_result.result)) + + def __load_ugoira_pics(self, file_path: str) -> Dict[str, bytes]: + if not self.__is_data_loaded: + raise RuntimeError('Illust data not loaded!') + if not os.path.exists(file_path): + raise RuntimeError(f'File: {file_path}, Not found.') + result_list = {} + with zipfile.ZipFile(file_path, 'r') as zip_f: + name_list = zip_f.namelist() + for file_name in name_list: + result_list.update({ + file_name: zip_f.open(file_name, 'r').read() + }) + return result_list + + def __generate_ugoira_gif(self, ugoira_pics: Dict[str, bytes]) -> bytes: + if not self.__is_data_loaded: + raise RuntimeError('Illust data not loaded!') + frames_list = [] + sum_delay = [] + for file, delay in [(item['file'], item['delay']) for item in self.__illust_data['ugoira_meta']['frames']]: + frames_list.append(imageio.imread(ugoira_pics[file])) + sum_delay.append(delay) + avg_delay = sum(sum_delay) / len(sum_delay) + avg_duration = avg_delay / 1000 + with BytesIO() as bytes_f: + imageio.mimsave(bytes_f, frames_list, 'GIF', duration=avg_duration) + return bytes_f.getvalue() + + async def __prepare_ugoira_gif(self) -> bytes: + if self.__is_data_loaded: + illust_data = self.__illust_data + else: + illust_data_result = await self.get_illust_data() + if illust_data_result.error: + raise RuntimeError('Fetch illust data failed') + illust_data = dict(illust_data_result.result) + + illust_type = illust_data.get('illust_type') + if illust_type != 2: + raise RuntimeError('Illust not ugoira!') + + ugoira_zip_dl_url = illust_data.get('ugoira_meta').get('originalsrc') + if not ugoira_zip_dl_url: + raise RuntimeError('Can not get ugoira download url!') + + zip_file_name = os.path.split(ugoira_zip_dl_url)[-1] + download_result = await self.download_illust() + if download_result.error: + raise RuntimeError(f'Download ugoira Illust failed: {download_result.info}') + + folder_path = os.path.split(download_result.result)[0] + ugoira_zip_path = os.path.abspath(os.path.join(folder_path, zip_file_name)) + + loop = asyncio.get_running_loop() + ugoira_pics = await loop.run_in_executor(None, self.__load_ugoira_pics, ugoira_zip_path) + gif_bytes = await loop.run_in_executor(None, self.__generate_ugoira_gif, ugoira_pics) + return gif_bytes + + async def get_ugoira_gif_base64(self) -> Result.TextResult: + try: + gif_bytes = await self.__prepare_ugoira_gif() + base64_result = PicEncoder.bytes_to_b64(image=gif_bytes) + return base64_result + except Exception as e: + return Result.TextResult(error=True, info=repr(e), result='') + + async def get_ugoira_gif_filepath(self) -> Result.TextResult: + try: + gif_bytes = await self.__prepare_ugoira_gif() + folder_path = os.path.abspath(os.path.join(TMP_PATH, 'pixiv_illust')) + # 检查保存文件路径 + if not os.path.exists(folder_path): + os.makedirs(folder_path) + file_name = f'{self.__pid}.gif' + file_path = os.path.abspath(os.path.join(folder_path, file_name)) + async with aiofiles.open(file_path, 'wb') as aio_f: + await aio_f.write(gif_bytes) + file_url = pathlib.Path(file_path).as_uri() + return Result.TextResult(error=False, info='Success', result=file_url) + except Exception as e: + return Result.TextResult(error=True, info=repr(e), result='') + async def download_illust(self, page: int = None) -> Result.TextResult: """ :param page: 仅下载特定页码 @@ -314,20 +480,24 @@ async def download_illust(self, page: int = None) -> Result.TextResult: if page and page < 1: page = None - illust_data_result = await self.get_illust_data() - if illust_data_result.error: - return Result.TextResult(error=True, info='Fetch illust data failed', result='') + if self.__is_data_loaded: + illust_data = self.__illust_data + else: + illust_data_result = await self.get_illust_data() + if illust_data_result.error: + return Result.TextResult(error=True, info='Fetch illust data failed', result='') + illust_data = dict(illust_data_result.result) download_url_list = [] - page_count = illust_data_result.result.get('page_count') - illust_type = illust_data_result.result.get('illust_type') + page_count = illust_data.get('page_count') + illust_type = illust_data.get('illust_type') if illust_type == 2: # 作品类型为动图 - download_url_list.append(illust_data_result.result.get('ugoira_meta').get('originalsrc')) + download_url_list.append(illust_data.get('ugoira_meta').get('originalsrc')) if page_count == 1: - download_url_list.append(illust_data_result.result.get('orig_url')) + download_url_list.append(illust_data.get('orig_url')) else: - download_url_list.extend(illust_data_result.result.get('all_url').get('original')) + download_url_list.extend(illust_data.get('all_url').get('original')) if page and page <= page_count: download_url_list = [download_url_list[page - 1]] @@ -351,6 +521,8 @@ async def download_illust(self, page: int = None) -> Result.TextResult: download_result = await fetcher.download_file(url=download_url_list[0], path=file_path, file_name=file_name) if download_result.success(): + self.__is_downloaded = True + self.__downloaded_file_path = download_result.result return Result.TextResult(error=False, info=file_name, result=download_result.result) else: return Result.TextResult(error=True, info=download_result.info, result='') @@ -369,8 +541,8 @@ async def download_illust(self, page: int = None) -> Result.TextResult: # 动图额外保存原始ugoira_meta信息 if illust_type == 2: - pid = illust_data_result.result.get('pid') - ugoira_meta = illust_data_result.result.get('ugoira_meta') + pid = illust_data.get('pid') + ugoira_meta = illust_data.get('ugoira_meta') ugoira_meta_file = os.path.abspath(os.path.join(file_path, f'{pid}_ugoira_meta')) async with aiofiles.open(ugoira_meta_file, 'w') as f: await f.write(json.dumps(ugoira_meta)) @@ -378,12 +550,136 @@ async def download_illust(self, page: int = None) -> Result.TextResult: # 打包 zip_result = await create_zip_file(files=downloaded_list, file_path=file_path, file_name=str(self.__pid)) + if zip_result.success(): + self.__is_downloaded = True + self.__downloaded_file_path = zip_result.result return zip_result else: return Result.TextResult(error=True, info='Get illust url failed', result='') + async def get_recommend(self, *, init_limit: int = 18, lang: str = 'zh') -> Result.DictResult: + """ + 获取作品对应的相关作品推荐 + :param init_limit: 初始化作品推荐时首次加载的作品数量 + :param lang: 语言 + :return: DictResult + illusts: List[Dict], 首次加载的推荐作品的详细信息 + nextIds: List, 剩余未加载推荐作品的pid列表 + details: Dict, 所有推荐作品获取关联信息 + """ + recommend_url = f'{self.ILLUST_DATA_URL}{self.__pid}/recommend/init' + illust_artworks_url = f'{self.ILLUST_ARTWORK_URL}{self.__pid}' + + headers = self.HEADERS.copy() + headers.update({ + 'accept': 'application/json', + 'referer': illust_artworks_url + }) + params = {'limit': init_limit, 'lang': lang} + fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_illust_recommend', headers=headers, cookies=COOKIES) + recommend_data_result = await fetcher.get_json(url=recommend_url, params=params) + + if recommend_data_result.error: + return Result.DictResult( + error=True, info=f'Fetch illust recommend failed, {recommend_data_result.info}', result={}) + + # 检查返回状态 + if recommend_data_result.result.get('error') or not recommend_data_result.result: + return Result.DictResult(error=True, info=f'PixivApiError: {recommend_data_result.result}', result={}) + + # 直接返回原始结果 + return Result.DictResult(error=False, info='Success', result=recommend_data_result.result.get('body')) + + +class PixivUser(Pixiv): + def __init__(self, uid: int): + self.__uid: int = uid + + async def get_info(self) -> Result.DictResult: + user_info_url = f'https://www.pixiv.net/ajax/user/{self.__uid}' + + headers = self.HEADERS.copy() + headers.update({'referer': f'https://www.pixiv.net/users/{self.__uid}'}) + + fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_user', headers=headers, cookies=COOKIES) + + # 获取用户信息 + params = {'lang': 'zh'} + user_info_result = await fetcher.get_json(url=user_info_url, params=params) + if user_info_result.error: + return Result.DictResult(error=True, info=f'Fetch user info failed, {user_info_result.info}', result={}) + + # 检查返回状态 + if user_info_result.result.get('error') or not user_info_result.result: + return Result.DictResult(error=True, info=f'PixivApiError: {user_info_result.result}', result={}) + + user_info = user_info_result.result + + try: + # 处理用户基本信息 + name = user_info['body'].get('name') + image = user_info['body'].get('image') + image_big = user_info['body'].get('imageBig') + partial = user_info['body'].get('partial') + premium = user_info['body'].get('premium') + sketch_live_id = user_info['body'].get('sketchLiveId') + sketch_lives = user_info['body'].get('sketchLives') + user_id = user_info['body'].get('userId') + + result = { + 'name': name, + 'image': image, + 'image_big': image_big, + 'partial': partial, + 'premium': premium, + 'sketch_live_id': sketch_live_id, + 'sketch_lives': sketch_lives, + 'user_id': user_id + } + return Result.DictResult(error=False, info='Success', result=result) + except Exception as e: + logger.error(f'PixivUser | Parse user info failed, error: {repr(e)}') + return Result.DictResult(error=True, info=f'Parse user info failed', result={}) + + async def get_artworks_info(self) -> Result.DictResult: + user_data_url = f'https://www.pixiv.net/ajax/user/{self.__uid}/profile/all' + + headers = self.HEADERS.copy() + headers.update({'referer': f'https://www.pixiv.net/users/{self.__uid}'}) + + fetcher = HttpFetcher(timeout=10, flag='pixiv_utils_user', headers=headers, cookies=COOKIES) + + # 获取作品信息 + params = {'lang': 'zh'} + user_data_result = await fetcher.get_json(url=user_data_url, params=params) + if user_data_result.error: + return Result.DictResult(error=True, info=f'Fetch user data failed, {user_data_result.info}', result={}) + + # 检查返回状态 + if user_data_result.result.get('error') or not user_data_result.result: + return Result.DictResult(error=True, info=f'PixivApiError: {user_data_result.result}', result={}) + + user_data = user_data_result.result + + try: + # 处理作品基本信息 + illust_list = [int(pid) for pid in dict(user_data['body']['illusts']).keys()] + manga_list = [int(pid) for pid in dict(user_data['body']['manga']).keys()] + novels_list = [int(nid) for nid in dict(user_data['body']['novels']).keys()] + + result = { + 'illust_list': illust_list, + 'manga_list': manga_list, + 'novels_list': novels_list + } + return Result.DictResult(error=False, info='Success', result=result) + except Exception as e: + logger.error(f'PixivUser | Parse user data failed, error: {repr(e)}') + return Result.DictResult(error=True, info=f'Parse user data failed', result={}) + __all__ = [ 'Pixiv', - 'PixivIllust' + 'PixivIllust', + 'PixivUser' ] diff --git a/omega_miya/utils/pixiv_utils/pixivision.py b/omega_miya/utils/pixiv_utils/pixivision.py index 7021638e..3815a8a1 100644 --- a/omega_miya/utils/pixiv_utils/pixivision.py +++ b/omega_miya/utils/pixiv_utils/pixivision.py @@ -127,17 +127,11 @@ async def get_article_info(self) -> Result.DictResult: url = illust_info.attrs['href'] # info = illust_info.get_text(strip=True) # 识别pid - text_o = re.findall(r'illust_id=[0-9]+', url) - text_n = re.findall(r'net/artworks/[0-9]+', url) - text_p = re.findall(r'pixiv\.net/i/[0-9]+', url) - if text_o: - pid = re.search(r'[0-9]+', text_o[0]).group() + if url_new := re.search(r'https?://.*?pixiv\.net/(artworks|i)/(\d+)', url): + pid = int(url_new.groups()[1]) url = f'https://www.pixiv.net/artworks/{pid}' - elif text_n: - pid = re.search(r'[0-9]+', text_n[0]).group() - url = f'https://www.pixiv.net/artworks/{pid}' - elif text_p: - pid = re.search(r'[0-9]+', text_p[0]).group() + elif url_old := re.search(r'https?://.*?pixiv\.net.*?illust_id=(\d+)', url): + pid = int(url_old.groups()[0]) url = f'https://www.pixiv.net/artworks/{pid}' else: logger.debug(f'PixivisionArticle | Illust in article {self.__aid} not found, ignored.') diff --git a/omega_miya/utils/tencent_cloud_api/__init__.py b/omega_miya/utils/tencent_cloud_api/__init__.py index 76e13bdb..ea41c0ba 100644 --- a/omega_miya/utils/tencent_cloud_api/__init__.py +++ b/omega_miya/utils/tencent_cloud_api/__init__.py @@ -1,6 +1,8 @@ from .nlp import TencentNLP +from .tmt import TencentTMT __all__ = [ - 'TencentNLP' + 'TencentNLP', + 'TencentTMT' ] diff --git a/omega_miya/utils/tencent_cloud_api/cloud_api.py b/omega_miya/utils/tencent_cloud_api/cloud_api.py index aae14651..6d4c156d 100644 --- a/omega_miya/utils/tencent_cloud_api/cloud_api.py +++ b/omega_miya/utils/tencent_cloud_api/cloud_api.py @@ -2,9 +2,9 @@ import hashlib import hmac import datetime -from dataclasses import dataclass from typing import Dict, Any import nonebot +from omega_miya.utils.Omega_Base import Result from omega_miya.utils.Omega_plugin_utils import HttpFetcher global_config = nonebot.get_driver().config @@ -13,18 +13,6 @@ class TencentCloudApi(object): - @dataclass - class ApiRes: - error: bool - info: str - result: dict - - def success(self) -> bool: - if not self.error: - return True - else: - return False - def __init__(self, secret_id: str, secret_key: str, @@ -53,7 +41,7 @@ def __upgrade_signed_header(self, region: str, version: str, payload: Dict[str, Any]): - self.__headers = { + self.__headers.update({ 'Authorization': self.__sign_v3(payload=payload), 'Content-Type': 'application/json', 'Host': self.__host, @@ -61,7 +49,7 @@ def __upgrade_signed_header(self, 'X-TC-Region': region, 'X-TC-Timestamp': str(self.__request_timestamp), 'X-TC-Version': version - } + }) def __canonical_request(self, payload: Dict[str, Any], @@ -115,13 +103,13 @@ def __sign(key, msg): return authorization - async def post_request(self, action: str, region: str, version: str, payload: Dict[str, Any]) -> ApiRes: + async def post_request(self, action: str, region: str, version: str, payload: Dict[str, Any]) -> Result.DictResult: self.__upgrade_signed_header(action=action, region=region, version=version, payload=payload) fetcher = HttpFetcher(timeout=10, flag=f'tencent_cloud_api_{action}', headers=self.__headers) result = await fetcher.post_json(url=self.__endpoint, json=payload) - return self.ApiRes(error=result.error, info=result.info, result=result.result) + return Result.DictResult(error=result.error, info=result.info, result=result.result) __all__ = [ diff --git a/omega_miya/utils/tencent_cloud_api/nlp.py b/omega_miya/utils/tencent_cloud_api/nlp.py index 38841c86..09ca8643 100644 --- a/omega_miya/utils/tencent_cloud_api/nlp.py +++ b/omega_miya/utils/tencent_cloud_api/nlp.py @@ -1,5 +1,7 @@ import re import json +from typing import Optional +from omega_miya.utils.Omega_Base import Result from .cloud_api import SECRET_ID, SECRET_KEY, TencentCloudApi @@ -8,7 +10,7 @@ def __init__(self): self.__secret_id = SECRET_ID self.__secret_key = SECRET_KEY - async def chat_bot(self, text: str, flag: int = 0) -> TencentCloudApi.ApiRes: + async def chat_bot(self, text: str, flag: int = 0) -> Result.TextResult: payload = { 'Query': text, 'Flag': flag} @@ -19,16 +21,20 @@ async def chat_bot(self, text: str, flag: int = 0) -> TencentCloudApi.ApiRes: result = await api.post_request( action='ChatBot', version='2019-04-08', region='ap-guangzhou', payload=payload) - if not result.error: + if result.success(): if result.result['Response'].get('Error'): - result.error = True - result.info = result.result['Response'].get('Error') + return Result.TextResult( + error=True, info=f"API error: {result.result['Response'].get('Error')}", result='') else: - result.info = f"Confidence: {result.result['Response']['Confidence']}" - result.result = result.result['Response']['Reply'] - return result + return Result.TextResult( + error=False, + info=f"Success with confidence: {result.result['Response']['Confidence']}", + result=result.result['Response']['Reply'] + ) + else: + return Result.TextResult(error=True, info=result.info, result='') - async def describe_entity(self, entity_name: str, attr: str = None) -> TencentCloudApi.ApiRes: + async def describe_entity(self, entity_name: str, attr: str = '简介') -> Result.TextResult: payload = {'EntityName': entity_name} api = TencentCloudApi( secret_id=self.__secret_id, @@ -37,30 +43,24 @@ async def describe_entity(self, entity_name: str, attr: str = None) -> TencentCl result = await api.post_request( action='DescribeEntity', version='2019-04-08', region='ap-guangzhou', payload=payload) - if not result.error: - result.result = result.result + if result.success(): if result.result['Response'].get('Error'): - result.error = True - result.info = result.result['Response'].get('Error') + return Result.TextResult( + error=True, info=f"API error: {result.result['Response'].get('Error')}", result='') else: - if attr: - content = json.loads(result.result['Response']['Content']) - attr_content = content.get(attr) - if attr_content: - attr_text = '\n'.join([re.sub(r'\s{2,}', '', x.get('Name')) for x in attr_content]) - result.result = attr_text - else: - result.error = True - result.info = 'Attributes not found' + content = json.loads(result.result['Response']['Content']) + attr_content = content.get(attr) + if isinstance(attr_content, list): + attr_text = ';\n'.join([ + re.sub(r'\s{2,}', '', str(x.get('Name'))).replace('|@|', '\n') for x in attr_content + ]) + return Result.TextResult(error=False, info='Success', result=attr_text) else: - text = json.loads(result.result['Response']['Content'])['简介'][0]['Name'] - text = re.sub(r'\s{2,}', '', str(text)) - text = text.replace('|@|', '\n') - # text = re.sub(r'\s{2,}', '', str(text)).split('|@|')[0] - result.result = text - return result - - async def describe_relation(self, left_entity_name: str, right_entity_name: str) -> TencentCloudApi.ApiRes: + return Result.TextResult(error=True, info='Attributes not found', result='') + else: + return Result.TextResult(error=True, info=result.info, result='') + + async def describe_relation(self, left_entity_name: str, right_entity_name: str) -> Result.TextResult: payload = {'LeftEntityName': left_entity_name, 'RightEntityName': right_entity_name} api = TencentCloudApi( secret_id=self.__secret_id, @@ -68,18 +68,138 @@ async def describe_relation(self, left_entity_name: str, right_entity_name: str) host='nlp.tencentcloudapi.com') result = await api.post_request( action='DescribeRelation', version='2019-04-08', region='ap-guangzhou', payload=payload) - if not result.error: + + if result.success(): if result.result['Response'].get('Error'): - result.error = True - result.info = result.result['Response'].get('Error') + return Result.TextResult( + error=True, info=f"API error: {result.result['Response'].get('Error')}", result='') else: - content = result.result = result.result['Response']['Content'] - res_list = \ - [(x['Object'][0]['Name'][0], x['Subject'][0]['Name'][0], x['Relation']) for x in content[:-1]] + content = result.result['Response']['Content'] + res_list = [ + (x['Object'][0]['Name'][0], x['Subject'][0]['Name'][0], x['Relation']) for x in content[:-1] + ] msg = ';\n'.join([f'{x[0]}是{x[1]}的{x[2]}' for x in res_list]) - result.info = f'get relation: {len(res_list)}' - result.result = msg - return result + return Result.TextResult(error=False, info=f'Get relation: {len(res_list)}', result=msg) + else: + return Result.TextResult(error=True, info=result.info, result='') + + async def sentiment_analysis(self, text: str, *, flag: int = 4, mode: str = '2class') -> Result.DictResult: + """ + 情感分析 + :param text: 待分析的文本(仅支持UTF-8格式, 不超过200字) + :param flag: 待分析文本所属的类型, 仅当输入参数Mode取值为2class时有效(默认取4值) + 1: 商品评论类 + 2: 社交类 + 3: 美食酒店类 + 4: 通用领域类 + :param mode: 情感分类模式选项, 可取2class或3class(默认值为2class) + 2class: 返回正负面二分类情感结果 + 3class: 返回正负面及中性三分类情感结果 + :return: DictResult + Positive: Float, 正面情感概率 + Neutral: Float, 中性情感概率, 当输入参数Mode取值为3class时有效, 否则值为空。注意: 此字段可能返回 null, 表示取不到有效值。 + Negative: Float, 负面情感概率 + Sentiment: String, 情感分类结果 + positive: 表示正面情感 + negative: 表示负面情感 + neutral: 表示中性、无情感 + RequestId: String, 唯一请求 ID, 每次请求都会返回。定位问题时需要提供该次请求的 RequestId。 + """ + payload = {'Text': text, 'Flag': flag, 'Mode': mode} + api = TencentCloudApi( + secret_id=self.__secret_id, + secret_key=self.__secret_key, + host='nlp.tencentcloudapi.com') + result = await api.post_request( + action='SentimentAnalysis', version='2019-04-08', region='ap-guangzhou', payload=payload) + + if result.success(): + if result.result['Response'].get('Error'): + return Result.DictResult( + error=True, info=f"API error: {result.result['Response'].get('Error')}", result={}) + else: + response = dict(result.result['Response']) + return Result.DictResult(error=False, info='Success', result=response) + else: + return Result.DictResult(error=True, info=result.info, result={}) + + async def sentiment_tendency(self, text: str, *, flag: int = 4, mode: str = '3class') -> Result.IntResult: + """ + 使用sentiment_analysis情感分析后直接返回情感倾向标签 + :param text: 待分析的文本(仅支持UTF-8格式, 不超过200字) + :param flag: 待分析文本所属的类型, 同sentiment_analysis, 默认为4 + :param mode: 情感分类模式选项, 同sentiment_analysis, 默认为3class + :return: IntResult + 1: positive, 表示正面情感 + 0: neutral, 表示中性、无情感 + -1: negative, 表示负面情感 + """ + result = await self.sentiment_analysis(text=text, flag=flag, mode=mode) + if result.success(): + sentiment = result.result.get('Sentiment') + if sentiment == 'positive': + return Result.IntResult(error=False, info='Success', result=1) + elif sentiment == 'neutral': + return Result.IntResult(error=False, info='Success', result=0) + elif sentiment == 'negative': + return Result.IntResult(error=False, info='Success', result=-1) + else: + return Result.IntResult(error=True, info=f'Sentiment result not found', result=-100) + else: + return Result.IntResult(error=True, info=result.info, result=-101) + + async def lexical_analysis(self, text: str, *, dict_id: Optional[str] = None, flag: int = 2) -> Result.DictResult: + """ + 词法分析 + :param text: 待分析的文本(仅支持UTF-8格式, 不超过500字) + :param dict_id: 指定要加载的自定义词库ID + :param flag: 词法分析模式(默认取2值) + 1: 高精度(混合粒度分词能力) + 2: 高性能(单粒度分词能力) + :return: DictResult + NerTokens: Array of NerToken, 命名实体识别结果。取值范围, 注意:此字段可能返回 null, 表示取不到有效值 + PER: 表示人名, 如刘德华、贝克汉姆 + LOC: 表示地名, 如北京、华山 + ORG: 表示机构团体名, 如腾讯、最高人民法院、人大附中 + PRODUCTION: 表示产品名, 如QQ、微信、iPhone + PosTokens: Array of PosToken, 分词&词性标注结果(词性表请参见附录) + RequestId: String, 唯一请求 ID, 每次请求都会返回。定位问题时需要提供该次请求的 RequestId + """ + payload = {'Text': text, 'DictId': dict_id, 'Flag': flag} if dict_id else {'Text': text, 'Flag': flag} + api = TencentCloudApi( + secret_id=self.__secret_id, + secret_key=self.__secret_key, + host='nlp.tencentcloudapi.com') + result = await api.post_request( + action='LexicalAnalysis', version='2019-04-08', region='ap-guangzhou', payload=payload) + + if result.success(): + if result.result['Response'].get('Error'): + return Result.DictResult( + error=True, info=f"API error: {result.result['Response'].get('Error')}", result={}) + else: + response = dict(result.result['Response']) + return Result.DictResult(error=False, info='Success', result=response) + else: + return Result.DictResult(error=True, info=result.info, result={}) + + async def participle_and_tagging( + self, text: str, *, dict_id: Optional[str] = None, flag: int = 2) -> Result.TupleListResult: + """ + 使用lexical_analysis进行分词和词性标注 + :param text: 待分析的文本(仅支持UTF-8格式, 不超过500字) + :param dict_id: 指定要加载的自定义词库ID + :param flag: 词法分析模式(默认取2值) + :return: TupleListResult[Word: str, Pos: str] + Word: 基础词 + Pos: 词性 + """ + result = await self.lexical_analysis(text=text, dict_id=dict_id, flag=flag) + if result.success(): + participle_result = [(x.get('Word'), x.get('Pos')) for x in result.result.get('PosTokens')] + return Result.TupleListResult(error=False, info='Success', result=participle_result) + else: + return Result.TupleListResult(error=True, info=result.info, result=[]) __all__ = [ diff --git a/omega_miya/utils/tencent_cloud_api/tmt.py b/omega_miya/utils/tencent_cloud_api/tmt.py new file mode 100644 index 00000000..67476588 --- /dev/null +++ b/omega_miya/utils/tencent_cloud_api/tmt.py @@ -0,0 +1,51 @@ +""" +@Author : Ailitonia +@Date : 2021/06/05 19:43 +@FileName : tmt.py +@Project : nonebot2_miya +@Description : 腾讯云机器翻译模块 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +from omega_miya.utils.Omega_Base import Result +from .cloud_api import SECRET_ID, SECRET_KEY, TencentCloudApi + + +class TencentTMT(object): + def __init__(self): + self.__secret_id = SECRET_ID + self.__secret_key = SECRET_KEY + + async def translate( + self, + source_text: str, + *, + source: str = 'auto', + target: str = 'zh', + project_id: int = 0) -> Result.DictResult: + payload = {'SourceText': source_text, 'Source': source, 'Target': target, 'ProjectId': project_id} + api = TencentCloudApi( + secret_id=self.__secret_id, + secret_key=self.__secret_key, + host='tmt.tencentcloudapi.com') + result = await api.post_request( + action='TextTranslate', version='2018-03-21', region='ap-chengdu', payload=payload) + + if result.error: + return result + response = dict(result.result.get('Response')) + if response.get('Error'): + return Result.DictResult(error=True, info=response.get('Error'), result={}) + + trans_result = { + 'source': response.get('Source'), + 'target': response.get('Target'), + 'targettext': response.get('TargetText') + } + return Result.DictResult(error=False, info='Success', result=trans_result) + + +__all__ = [ + 'TencentTMT' +] diff --git a/requirements.txt b/requirements.txt index ceab188b..71e3407d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,21 @@ nonebot2==2.0.0a13.post1 nonebot-adapter-cqhttp==2.0.0a13 -SQLAlchemy~=1.4.12 +SQLAlchemy~=1.4.22 mysqlclient~=2.0.3 aiomysql~=0.0.21 -aiocqhttp~=1.3.0 -aiofiles~=0.6.0 +aiocqhttp~=1.4.1 +aiofiles==0.6.0 bs4~=0.0.1 lxml~=4.6.3 -Pillow~=8.2.0 +numpy~=1.21.1 +Pillow~=8.3.1 +imageio~=2.9.0 beautifulsoup4~=4.9.3 -Jinja2~=2.11.3 aiohttp~=3.7.4.post0 xlwt~=1.3.0 ujson~=4.0.2 msgpack~=1.0.2 -pydantic~=1.8.1 +pydantic~=1.8.2 APScheduler~=3.7.0 pycryptodome~=3.10.1 -py7zr==0.15.2 \ No newline at end of file +py7zr==0.16.1 \ No newline at end of file diff --git a/test/init_all_auth_node.py b/test/init_all_auth_node.py index 5ad3df68..4130be7f 100644 --- a/test/init_all_auth_node.py +++ b/test/init_all_auth_node.py @@ -1,18 +1,22 @@ -from omega_miya.plugins.Omega_manage import init_group_auth_node, init_user_auth_node -from omega_miya.utils.Omega_Base import DBTable, DBFriend +from omega_miya.plugins.Omega_manager import init_group_auth_node, init_user_auth_node +from omega_miya.utils.Omega_Base import DBBot, DBFriend, DBBotGroup +from nonebot.adapters.cqhttp.bot import Bot import nonebot +driver = nonebot.get_driver() -async def init_all_auth_node(): - all_friends = await DBFriend.list_exist_friends() + +@driver.on_bot_connect +async def init_all_auth_node(bot: Bot): + self_bot = DBBot(self_qq=int(bot.self_id)) + all_friends = await DBFriend.list_exist_friends(self_bot=self_bot) for user_id in all_friends.result: if user_id in [123456789]: continue await init_user_auth_node(user_id=int(user_id)) print(f'Init_user_auth_node completed, user: {user_id}') - t = DBTable(table_name='Group') - group_res = await t.list_col('group_id') + group_res = await DBBotGroup.list_exist_bot_groups(self_bot=self_bot) all_groups = [int(x) for x in group_res.result] for group_id in all_groups: if group_id in [987654321]: @@ -20,5 +24,3 @@ async def init_all_auth_node(): await init_group_auth_node(group_id=int(group_id)) print(f'Init_group_auth_node completed, group: {group_id}') - -nonebot.get_driver().on_startup(init_all_auth_node) diff --git a/test/pixiv_illust_updater.py b/test/pixiv_illust_updater.py new file mode 100644 index 00000000..693c7df9 --- /dev/null +++ b/test/pixiv_illust_updater.py @@ -0,0 +1,144 @@ +""" +@Author : Ailitonia +@Date : 2021/06/19 1:24 +@FileName : pixiv_illust_updater.py +@Project : nonebot2_miya +@Description : 数据库pixiv illust作品图片链接信息更新工具 +@GitHub : https://github.com/Ailitonia +@Software : PyCharm +""" + +import asyncio +import json +from nonebot import on_command, logger +from nonebot.rule import to_me +from nonebot.permission import SUPERUSER +from nonebot.typing import T_State +from nonebot.adapters.cqhttp.bot import Bot +from nonebot.adapters.cqhttp.event import MessageEvent +from omega_miya.utils.Omega_Base import DBPixivillust, Result +from omega_miya.utils.pixiv_utils import PixivIllust + + +ONLY_UPDATE_NO_PAGES_ILLUST: bool = False + + +async def add_illust(pid: int, nsfw_tag: int) -> Result.IntResult: + illust_result = await PixivIllust(pid=pid).get_illust_data() + + if illust_result.success(): + illust_data = illust_result.result + title = illust_data.get('title') + uid = illust_data.get('uid') + uname = illust_data.get('uname') + url = illust_data.get('url') + tags = illust_data.get('tags') + is_r18 = illust_data.get('is_r18') + illust_pages = illust_data.get('illust_pages') + + if is_r18: + nsfw_tag = 2 + + illust = DBPixivillust(pid=pid) + illust_add_result = await illust.add(uid=uid, title=title, uname=uname, nsfw_tag=nsfw_tag, tags=tags, url=url) + if illust_add_result.error: + logger.error(f'Add illust failed: {illust_add_result.info}') + return Result.IntResult(error=True, info=illust_add_result.info, result=pid) + + for page, urls in illust_pages.items(): + original = urls.get('original') + regular = urls.get('regular') + small = urls.get('small') + thumb_mini = urls.get('thumb_mini') + page_upgrade_result = await illust.upgrade_page( + page=page, original=original, regular=regular, small=small, thumb_mini=thumb_mini) + if page_upgrade_result.error: + logger.warning(f'Upgrade illust page {page} failed: {page_upgrade_result.info}') + return illust_add_result + else: + return Result.IntResult(error=True, info=illust_result.info, result=pid) + + +async def output_nsfw(): + nsfw_tag = 2 + res = await DBPixivillust.list_all_illust_by_nsfw_tag(nsfw_tag=nsfw_tag) + dict_res = {x: nsfw_tag for x in res.result} + nsfw_json = f'C:\\nsfw_{nsfw_tag}.json' + with open(nsfw_json, 'w+') as f: + json.dump(dict_res, f) + + +async def reset_nsfw_tag(): + res = await DBPixivillust.reset_all_nsfw_tag() + print(res) + + +async def set_nsfw_tag(): + nsfw_tag = 2 + nsfw_json = f'C:\\nsfw_{nsfw_tag}.json' + with open(nsfw_json, 'r') as f: + tags = json.load(f) + res = await DBPixivillust.set_nsfw_tag(tags=tags) + print(res) + + +# 注册事件响应器 +pixiv_illust_page_updater = on_command('update_pages', rule=to_me(), permission=SUPERUSER, priority=10, block=True) + + +@pixiv_illust_page_updater.handle() +async def handle_first_receive(bot: Bot, event: MessageEvent, state: T_State): + illust_list = await DBPixivillust.list_all_illust() + if illust_list.error: + logger.error(f'Get illust list failed: {illust_list.info}') + await pixiv_illust_page_updater.finish('Get illust list failed.') + + all_illust = len(illust_list.result) + pid_list = [] + if ONLY_UPDATE_NO_PAGES_ILLUST: + for pid in illust_list.result: + pages = await DBPixivillust(pid=pid).get_all_page() + if pages.success() and not pages.result: + pid_list.append(pid) + else: + pid_list.extend(illust_list.result) + + # 导入操作 + all_count = len(pid_list) + success_count = 0 + failed_count = 0 + fail_list = [] + # 全部一起并发api撑不住, 做适当切分 + # 每个切片数量 + seg_n = 25 + pid_seg_list = [] + for i in range(0, all_count, seg_n): + pid_seg_list.append(pid_list[i:i + seg_n]) + # 每个切片打包一个任务 + seg_len = len(pid_seg_list) + process_rate = 0 + for seg_list in pid_seg_list: + tasks = [] + for pid in seg_list: + tasks.append(add_illust(pid=pid, nsfw_tag=0)) + # 进行异步处理 + _res = await asyncio.gather(*tasks) + # 对结果进行计数 + for item in _res: + if item.success(): + success_count += 1 + else: + failed_count += 1 + fail_list.append(item.result) + logger.error(f'upgrade illust {item.result} page failed: {item.info}') + # 显示进度 + process_rate += 1 + if process_rate % 10 == 0: + logger.info(f'Updater processing: {process_rate}/{seg_len}') + + logger.info(f'Updater: process complete,' + f'All illust: {all_illust},' + f'Total: {all_count},' + f'Success: {success_count},' + f'Failed: {failed_count},' + f'Fail List: {fail_list}')