Commit
✨ Add backoff strategy
AzideCupric committed Jun 20, 2024
1 parent e747c5b commit 68d5a49
Showing 5 changed files with 210 additions and 49 deletions.
39 changes: 2 additions & 37 deletions nonebot_bison/platform/bilibili/platforms.py
@@ -1,16 +1,13 @@
import re
import json
from copy import deepcopy
from functools import wraps
from enum import Enum, unique
from typing import NamedTuple
from typing_extensions import Self
from typing import TypeVar, NamedTuple
from collections.abc import Callable, Awaitable

from yarl import URL
from nonebot import logger
from httpx import AsyncClient
from httpx import URL as HttpxURL
from pydantic import Field, BaseModel, ValidationError
from nonebot.compat import type_validate_json, type_validate_python

@@ -19,8 +16,8 @@
from nonebot_bison.utils import text_similarity, decode_unicode_escapes
from nonebot_bison.types import Tag, Target, RawPost, ApiError, Category

from .scheduler import BilibiliSite, BililiveSite, BiliBangumiSite
from ..platform import NewMessage, StatusChange, CategoryNotSupport, CategoryNotRecognize
from .scheduler import BilibiliSite, BililiveSite, ApiCode352Error, BiliBangumiSite, retry_for_352
from .models import (
PostAPI,
UserAPI,
@@ -38,38 +35,6 @@
LiveRecommendMajor,
)

B = TypeVar("B", bound="Bilibili")
MAX_352_RETRY_COUNT = 3


class ApiCode352Error(Exception):
def __init__(self, url: HttpxURL) -> None:
msg = f"api {url} error"
super().__init__(msg)


def retry_for_352(func: Callable[[B, Target], Awaitable[list[DynRawPost]]]):
retried_times = 0

@wraps(func)
async def wrapper(bls: B, *args, **kwargs):
nonlocal retried_times
try:
res = await func(bls, *args, **kwargs)
except ApiCode352Error as e:
if retried_times < MAX_352_RETRY_COUNT:
retried_times += 1
logger.warning(f"获取动态列表失败,尝试刷新cookie: {retried_times}/{MAX_352_RETRY_COUNT}")
await bls.ctx.refresh_client()
return []  # return an empty list
else:
raise ApiError(e.args[0])
else:
retried_times = 0
return res

return wrapper


class _ProcessedText(NamedTuple):
title: str
131 changes: 130 additions & 1 deletion nonebot_bison/platform/bilibili/scheduler.py
@@ -1,16 +1,145 @@
from enum import Enum
from random import randint
from functools import wraps
from typing_extensions import override
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, TypeVar
from collections.abc import Callable, Awaitable

from httpx import AsyncClient
from httpx import URL as HttpxURL
from nonebot import logger, require
from playwright.async_api import Cookie

from nonebot_bison.types import Target
from nonebot_bison.types import Target, ApiError
from nonebot_bison.utils import Site, ClientManager, http_client

from .models import DynRawPost

if TYPE_CHECKING:
from .platforms import Bilibili

require("nonebot_plugin_htmlrender")
from nonebot_plugin_htmlrender import get_browser

B = TypeVar("B", bound="Bilibili")


class ApiCode352Error(Exception):
def __init__(self, url: HttpxURL) -> None:
msg = f"api {url} error"
super().__init__(msg)


class ScheduleLevel(Enum):
NORMAL = 0
REFRESH = 1
BACKOFF = 2
RAISE = 3

@staticmethod
def level_up(level: "ScheduleLevel"):
if level.value == ScheduleLevel.RAISE.value:
return ScheduleLevel.RAISE
return ScheduleLevel(level.value + 1)


class ScheduleState:
MAX_REFRESH_COUNT = 3
MAX_BACKOFF_COUNT = 2
BACKOFF_TIMEDELTA = timedelta(minutes=5)

current_times: int
level: ScheduleLevel
backoff_start_time: datetime
latest_normal_return: list[DynRawPost] = []

def __init__(self):
self.current_times = 1
self.level = ScheduleLevel.NORMAL

def increase(self):
match self.level:
case ScheduleLevel.NORMAL:
logger.warning("获取动态列表失败,进入刷新cookie状态")
self.level = ScheduleLevel.level_up(self.level)
case ScheduleLevel.REFRESH if self.current_times < self.MAX_REFRESH_COUNT:
self.current_times += 1
logger.warning(
f"Fetching the dynamic feed failed again after this refresh; will try refreshing the cookie again before the next request: {self.current_times}/{self.MAX_REFRESH_COUNT}"
)
case ScheduleLevel.REFRESH if self.current_times >= self.MAX_REFRESH_COUNT:
self.current_times = 1
logger.trace(f"set backoff_start_time: {datetime.now()} in REFRESH level up")
logger.warning("刷新cookie失败,进行回避状态")
self.backoff_start_time = datetime.now()
self.level = ScheduleLevel.level_up(self.level)
case ScheduleLevel.BACKOFF if self.current_times < self.MAX_BACKOFF_COUNT:
self.current_times += 1
logger.trace(f"set backoff_start_time: {datetime.now()} in BACKOFF level")
logger.warning(
f"The previous backoff attempt failed; waiting for the next one ({self.current_times}/{self.MAX_BACKOFF_COUNT}), expected wait "
f"{self.BACKOFF_TIMEDELTA * self.current_times ** 2}"
)
self.backoff_start_time = datetime.now()
case ScheduleLevel.BACKOFF if self.current_times >= self.MAX_BACKOFF_COUNT:
self.current_times = 1
self.level = ScheduleLevel.level_up(self.level)
logger.error("获取动态列表失败,尝试刷新cookie失败,尝试回避失败,放弃")
case ScheduleLevel.RAISE:
logger.critical("Already in RAISE level")
case _:
raise ValueError("ScheduleLevel Error")

def reset(self):
self.current_times = 1
self.level = ScheduleLevel.NORMAL

def is_in_backoff_time(self):
"""是否在指数回避时间内"""
# 指数回避
logger.trace(f"current_times: {self.current_times}")
logger.trace(f"backoff_start_time: {self.backoff_start_time}")
logger.trace(f"now: {datetime.now()}")
logger.trace(f"delta: {datetime.now() - self.backoff_start_time}")
logger.trace(f"wait: {self.BACKOFF_TIMEDELTA * self.current_times ** 2}")

res = datetime.now() - self.backoff_start_time < self.BACKOFF_TIMEDELTA * self.current_times**2

logger.trace(f"result: {res}")
return res


def retry_for_352(func: Callable[[B, Target], Awaitable[list[DynRawPost]]]):
schedule_state = ScheduleState()

@wraps(func)
async def wrapper(bls: B, *args, **kwargs):
nonlocal schedule_state
try:
match schedule_state.level:
case ScheduleLevel.BACKOFF if schedule_state.is_in_backoff_time():
logger.warning("回避中,本次不进行请求")
return schedule_state.latest_normal_return
case ScheduleLevel.BACKOFF | ScheduleLevel.REFRESH:
logger.debug("尝试刷新客户端")
await bls.ctx.refresh_client()

res = await func(bls, *args, **kwargs)
except ApiCode352Error as e:
schedule_state.increase()
match schedule_state.level:
case ScheduleLevel.RAISE:
raise ApiError(e.args[0]) from e
case _:
return schedule_state.latest_normal_return
else:
schedule_state.reset()
schedule_state.latest_normal_return = res
return res

return wrapper


class BilibiliClientManager(ClientManager):
_client: AsyncClient
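For orientation, the sketch below (not part of this commit) isolates the exponential-backoff check that `ScheduleState.is_in_backoff_time` performs, assuming the same `BACKOFF_TIMEDELTA` and `current_times ** 2` growth factor shown above; the constants here are illustrative copies, not imports from the plugin.

from datetime import datetime, timedelta

# Illustrative copies of the ScheduleState constants above.
BACKOFF_TIMEDELTA = timedelta(minutes=5)
MAX_BACKOFF_COUNT = 2


def backoff_window(current_times: int) -> timedelta:
    """Length of the backoff window for the given attempt counter."""
    return BACKOFF_TIMEDELTA * current_times**2


def is_in_backoff_time(backoff_start_time: datetime, current_times: int) -> bool:
    """Same comparison as ScheduleState.is_in_backoff_time: stay quiet while inside the window."""
    return datetime.now() - backoff_start_time < backoff_window(current_times)


if __name__ == "__main__":
    for attempt in range(1, MAX_BACKOFF_COUNT + 1):
        # attempt 1 waits 0:05:00, attempt 2 waits 0:20:00
        print(f"attempt {attempt}: window {backoff_window(attempt)}")

While the check returns True, `retry_for_352` answers with `latest_normal_return` instead of calling the API.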
21 changes: 20 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -60,6 +60,7 @@ pytest-cov = ">=3,<6"
pytest-mock = "^3.10.0"
pytest-xdist = { extras = ["psutil"], version = "^3.1.0" }
respx = ">=0.20,<0.22"
freezegun = "^1.5.1"

[tool.poetry.group.docker]
optional = true
67 changes: 57 additions & 10 deletions tests/platforms/test_bilibili.py
@@ -4,8 +4,10 @@

import respx
import pytest
from loguru import logger
from nonebug.app import App
from httpx import URL, Response
from freezegun import freeze_time
from nonebot.compat import model_dump, type_validate_python

from .utils import get_json
@@ -59,8 +61,9 @@ async def test_retry_for_352(app: App):
from nonebot_bison.post import Post
from nonebot_bison.platform.platform import NewMessage
from nonebot_bison.types import Target, RawPost, ApiError
from nonebot_bison.platform.bilibili.scheduler import ScheduleState
from nonebot_bison.utils import ClientManager, ProcessContext, http_client
from nonebot_bison.platform.bilibili.platforms import MAX_352_RETRY_COUNT, ApiCode352Error, retry_for_352
from nonebot_bison.platform.bilibili.platforms import ApiCode352Error, retry_for_352

now = time()
raw_post_1 = {"id": 1, "text": "p1", "date": now, "tags": ["tag1"], "category": 1}
@@ -118,18 +121,22 @@ class MockClientManager(ClientManager):
refresh_client_call_count = 0

async def get_client(self, target: Target | None):
logger.debug(f"call get_client: {target}, {datetime.now()}")
self.get_client_call_count += 1
return http_client()

async def get_client_for_static(self):
logger.debug(f"call get_client_for_static: {datetime.now()}")
self.get_client_for_static_call_count += 1
return http_client()

async def get_query_name_client(self):
logger.debug(f"call get_query_name_client: {datetime.now()}")
self.get_query_name_client_call_count += 1
return http_client()

async def refresh_client(self):
logger.debug(f"call refresh_client: {datetime.now()}")
self.refresh_client_call_count += 1

fakebili = MockPlatform(ProcessContext(MockClientManager()))
@@ -154,16 +161,56 @@ async def refresh_client(self):
assert client_mgr.get_client_call_count == 2
assert client_mgr.refresh_client_call_count == 0

# exception case
freeze_start = datetime(2024, 6, 19, 0, 0, 0, 0)
fakebili.set_raise352(True)
for i in range(MAX_352_RETRY_COUNT):
res1: list[dict[str, Any]] = await fakebili.get_sub_list(Target("1")) # type: ignore
assert len(res1) == 0
assert client_mgr.get_client_call_count == 3 + i
assert client_mgr.refresh_client_call_count == i + 1
# max retry count exceeded, raise an exception
with pytest.raises(ApiError):
await fakebili.get_sub_list(Target("1"))
# exception case
with freeze_time(freeze_start):
for i in range(1, ScheduleState.MAX_REFRESH_COUNT + 1):
logger.debug(f"refresh count: {i}, {datetime.now()}")
res1: list[dict[str, Any]] = await fakebili.get_sub_list(Target("1")) # type: ignore
assert len(res1) == 2 # result from the last successful call
assert client_mgr.get_client_call_count == 2 + i
assert client_mgr.refresh_client_call_count == i - 1

# this call is the final failed refresh retry
logger.debug(f"latest refresh: {datetime.now()}")
res2: list[dict[str, Any]] = await fakebili.get_sub_list(Target("1")) # type: ignore
assert len(res2) == 2 # result from the last successful call
assert client_mgr.get_client_call_count == 2 + ScheduleState.MAX_REFRESH_COUNT + 1
assert client_mgr.refresh_client_call_count == ScheduleState.MAX_REFRESH_COUNT

assert client_mgr.get_client_call_count == 2 + ScheduleState.MAX_REFRESH_COUNT + 1
assert client_mgr.refresh_client_call_count == ScheduleState.MAX_REFRESH_COUNT

# max retry count exceeded, enter backoff

# inside the backoff window, no request is made
with freeze_time(freeze_start + ScheduleState.BACKOFF_TIMEDELTA / 2):
logger.debug(f"in backoff time, {datetime.now()}")
res3: list[dict[str, Any]] = await fakebili.get_sub_list(Target("1")) # type: ignore
assert len(res3) == 2 # result from the last successful call
assert client_mgr.get_client_call_count == 2 + ScheduleState.MAX_REFRESH_COUNT + 1
assert client_mgr.refresh_client_call_count == ScheduleState.MAX_REFRESH_COUNT

# perform the backoff attempts
for i in range(1, ScheduleState.MAX_BACKOFF_COUNT + 1):
new_freeze_start = freeze_start + ScheduleState.BACKOFF_TIMEDELTA * i**2
with freeze_time(new_freeze_start):
logger.debug(f"backoff count: {i}, {datetime.now()}")
# if this is the last backoff attempt, an exception should be raised
if i == ScheduleState.MAX_BACKOFF_COUNT:
with pytest.raises(ApiError):
await fakebili.get_sub_list(Target("1"))
continue
res2: list[dict[str, Any]] = await fakebili.get_sub_list(Target("1")) # type: ignore
assert len(res2) == 2 # result from the last successful call
assert client_mgr.get_client_call_count == 3 + ScheduleState.MAX_REFRESH_COUNT + i
assert client_mgr.refresh_client_call_count == ScheduleState.MAX_REFRESH_COUNT + i

freeze_start = new_freeze_start

assert client_mgr.get_client_call_count == 3 + ScheduleState.MAX_REFRESH_COUNT + ScheduleState.MAX_BACKOFF_COUNT
assert client_mgr.refresh_client_call_count == ScheduleState.MAX_REFRESH_COUNT + ScheduleState.MAX_BACKOFF_COUNT


@pytest.mark.asyncio
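As a side note (also not part of the diff), the `freeze_time` pattern the updated test relies on boils down to the following sketch; the dates are arbitrary.

from datetime import datetime, timedelta

from freezegun import freeze_time

start = datetime(2024, 6, 19, 0, 0, 0)

with freeze_time(start):
    # Inside the block, "now" is pinned to the frozen instant.
    assert datetime.now() == start

with freeze_time(start + timedelta(minutes=20)):
    # Advancing the frozen clock lets the test step past (or stay inside)
    # a backoff window deterministically, without real waiting.
    assert datetime.now() - start == timedelta(minutes=20)

This is why `freezegun` is added alongside the pytest dependencies in pyproject.toml.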
