Skip to content

Commit

Permalink
chore: Merge latest code
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Sep 9, 2024
2 parents cadc59b + 65c875d commit 2e99116
Show file tree
Hide file tree
Showing 18 changed files with 2,126 additions and 251 deletions.
2 changes: 1 addition & 1 deletion dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def initialize_components(
system_app.register(
DefaultExecutorFactory, max_workers=param.default_thread_pool_size
)
system_app.register(DefaultScheduler)
system_app.register(DefaultScheduler, scheduler_enable=CFG.SCHEDULER_ENABLED)
system_app.register_instance(controller)
system_app.register(ConnectorManager)

Expand Down
18 changes: 14 additions & 4 deletions dbgpt/app/operators/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union, cast

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
self._max_token_limit = max_token_limit
self._sub_compose_dag = self._build_conversation_composer_dag()
self._sub_compose_dag: Optional[DAG] = None

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
conv_serve = ConversationServe.get_instance(self.system_app)
Expand Down Expand Up @@ -166,7 +166,7 @@ async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
"messages": history_messages,
"prompt_dict": prompt_dict,
}
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
end_node: BaseOperator = cast(BaseOperator, self.sub_compose_dag.leaf_nodes[0])
# Sub dag, use the same dag context in the parent dag
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
model_request = ModelRequest.build_request(
Expand All @@ -184,6 +184,12 @@ async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
storage_conv.add_user_message(user_input)
return model_request

@property
def sub_compose_dag(self) -> DAG:
    """Return the conversation composer sub-DAG, building it on first use.

    The DAG is not created in ``__init__``; it is built by
    ``_build_conversation_composer_dag`` the first time this property is
    read and then cached on ``self._sub_compose_dag`` for later accesses.

    Returns:
        DAG: The cached (or freshly built) composer DAG.
    """
    cached_dag = self._sub_compose_dag
    if cached_dag:
        # Already built on an earlier access; reuse it.
        return cached_dag
    self._sub_compose_dag = self._build_conversation_composer_dag()
    return self._sub_compose_dag

def _build_storage(
self, req: CommonLLMHttpRequestBody
) -> Tuple[StorageConversation, List[BaseMessage]]:
Expand All @@ -207,7 +213,11 @@ def _build_storage(
return storage_conv, history_messages

def _build_conversation_composer_dag(self) -> DAG:
with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
default_dag_variables = self.dag._default_dag_variables if self.dag else None
with DAG(
"dbgpt_awel_app_chat_history_prompt_composer",
default_dag_variables=default_dag_variables,
) as composer_dag:
input_task = InputOperator(input_source=SimpleCallDataInputSource())
# History transform task
if self._history_merge_mode == "token":
Expand Down
20 changes: 12 additions & 8 deletions dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,11 @@ async def _resolve_variables(self, dag_ctx: DAGContext):
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
"""
from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder
from ...interface.variables import (
VariablesIdentifier,
VariablesPlaceHolder,
is_variable_string,
)

if not self._variables_provider:
return
Expand All @@ -432,11 +436,13 @@ async def _resolve_variables(self, dag_ctx: DAGContext):
resolve_items = []
for item in dag_ctx._dag_variables.items:
# TODO: Resolve variables just once?
if not item.value:
continue
if isinstance(item.value, str) and is_variable_string(item.value):
item.value = VariablesPlaceHolder(item.name, item.value)
if isinstance(item.value, VariablesPlaceHolder):
resolve_tasks.append(
self.blocking_func_to_async(
item.value.parse, self._variables_provider
)
item.value.async_parse(self._variables_provider)
)
resolve_items.append(item)
resolved_values = await asyncio.gather(*resolve_tasks)
Expand All @@ -462,15 +468,13 @@ async def _resolve_variables(self, dag_ctx: DAGContext):

if dag_provider:
# First try to resolve the variable with the DAG variables
resolved_value = await self.blocking_func_to_async(
value.parse,
resolved_value = await value.async_parse(
dag_provider,
ignore_not_found_error=True,
default_identifier_map=default_identifier_map,
)
if resolved_value is None:
resolved_value = await self.blocking_func_to_async(
value.parse,
resolved_value = await value.async_parse(
self._variables_provider,
default_identifier_map=default_identifier_map,
)
Expand Down
8 changes: 6 additions & 2 deletions dbgpt/core/awel/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import pytest
import pytest_asyncio

from dbgpt.component import SystemApp

from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
from ..task.task_impl import _is_async_iterator


Expand Down Expand Up @@ -104,7 +106,9 @@ async def stream_input_nodes(request):

@asynccontextmanager
async def _create_variables(**kwargs):
vp = StorageVariablesProvider()
sys_app = SystemApp()
DAGVar.set_current_system_app(sys_app)
vp = StorageVariablesProvider(system_app=sys_app)
vars = kwargs.get("vars")
if vars and isinstance(vars, dict):
for param_key, param_var in vars.items():
Expand Down
71 changes: 68 additions & 3 deletions dbgpt/core/interface/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,15 @@ def get(
) -> Any:
"""Query variables from storage."""

async def async_get(
    self,
    full_key: str,
    default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
    default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
    """Asynchronously query a variable from storage.

    This base implementation is a placeholder that always raises; providers
    that can serve asynchronous lookups override it.

    Args:
        full_key (str): The full key of the variable to look up.
        default_value (Optional[str]): Value to fall back to; defaults to the
            ``_EMPTY_DEFAULT_VALUE`` sentinel.
        default_identifier_map (Optional[Dict[str, str]]): Optional default
            identifier fields for the lookup.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Current variables provider does not support async.")

@abstractmethod
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
Expand Down Expand Up @@ -457,6 +466,24 @@ def parse(
return None
raise e

async def async_parse(
    self,
    variables_provider: VariablesProvider,
    ignore_not_found_error: bool = False,
    default_identifier_map: Optional[Dict[str, str]] = None,
):
    """Resolve this placeholder asynchronously through a variables provider.

    Args:
        variables_provider (VariablesProvider): Provider whose ``async_get``
            is awaited with this placeholder's ``full_key`` and
            ``default_value``.
        ignore_not_found_error (bool): When True, a ``ValueError`` raised by
            the provider is swallowed and ``None`` is returned instead.
        default_identifier_map (Optional[Dict[str, str]]): Forwarded to the
            provider unchanged.

    Returns:
        The value returned by the provider, or ``None`` when the lookup
        raised ``ValueError`` and ``ignore_not_found_error`` is set.

    Raises:
        ValueError: Propagated from the provider when
            ``ignore_not_found_error`` is False.
    """
    try:
        resolved = await variables_provider.async_get(
            self.full_key,
            self.default_value,
            default_identifier_map=default_identifier_map,
        )
    except ValueError:
        if not ignore_not_found_error:
            raise
        return None
    return resolved

def __repr__(self):
"""Return the representation of the variables place holder."""
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
Expand Down Expand Up @@ -508,6 +535,42 @@ def get(
variable.value = self.encryption.decrypt(variable.value, variable.salt)
return self._convert_to_value_type(variable)

async def async_get(
    self,
    full_key: str,
    default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
    default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
    """Asynchronously query a variable, falling back to ``async_get_variables``.

    The synchronous ``get`` is tried first (via
    ``blocking_func_to_async_no_executor``) with ``default_value=None``. If it
    yields ``None``, the key is parsed into a ``VariablesIdentifier`` and the
    variables returned by ``async_get_variables`` are searched by name.

    Args:
        full_key (str): The full key of the variable to look up.
        default_value (Optional[str]): Returned when nothing matches; if left
            at ``_EMPTY_DEFAULT_VALUE`` a missing variable raises instead.
        default_identifier_map (Optional[Dict[str, str]]): Default identifier
            fields used when parsing ``full_key``.

    Returns:
        Any: The stored value, the converted matching variable, or
        ``default_value``.

    Raises:
        ValueError: If no variable matches and no default was supplied, or if
            more than one variable matches ``full_key``.
    """
    # First attempt: the regular storage lookup.
    stored_value = await blocking_func_to_async_no_executor(
        self.get,
        full_key,
        default_value=None,
        default_identifier_map=default_identifier_map,
    )
    if stored_value is not None:
        return stored_value

    # Fallback: search the variables reported by async_get_variables.
    identifier = VariablesIdentifier.from_str_identifier(
        full_key, default_identifier_map
    )
    candidates = await self.async_get_variables(
        key=identifier.key,
        scope=identifier.scope,
        scope_key=identifier.scope_key,
        sys_code=identifier.sys_code,
        user_name=identifier.user_name,
    )
    matches = [item for item in candidates if item.name == identifier.name]

    if len(matches) > 1:
        raise ValueError(f"Multiple variables found for {full_key}")
    if matches:
        return self._convert_to_value_type(matches[0])
    if default_value == _EMPTY_DEFAULT_VALUE:
        raise ValueError(f"Variable {full_key} not found")
    return default_value

def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
if variables_item.category == "secret":
Expand Down Expand Up @@ -577,9 +640,11 @@ async def async_get_variables(
)
if is_builtin:
return builtin_variables
executor_factory: Optional[
DefaultExecutorFactory
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
executor_factory: Optional[DefaultExecutorFactory] = None
if self.system_app:
executor_factory = DefaultExecutorFactory.get_instance(
self.system_app, default_component=None
)
if executor_factory:
return await blocking_func_to_async(
executor_factory.create(),
Expand Down
31 changes: 16 additions & 15 deletions dbgpt/serve/flow/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from sqlalchemy import URL

from dbgpt.component import SystemApp
from dbgpt.core.interface.variables import VariablesProvider
from dbgpt.core.interface.variables import (
FernetEncryption,
StorageVariablesProvider,
VariablesProvider,
)
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager

Expand Down Expand Up @@ -33,6 +37,7 @@ def __init__(
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):

if api_prefix is None:
api_prefix = [f"/api/v1/serve/awel", "/api/v2/serve/awel"]
if api_tags is None:
Expand All @@ -41,8 +46,15 @@ def __init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
self._variables_provider: Optional[VariablesProvider] = None
self._serve_config: Optional[ServeConfig] = None
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._variables_provider: StorageVariablesProvider = StorageVariablesProvider(
storage=None,
encryption=FernetEncryption(self._serve_config.encrypt_key),
system_app=system_app,
)
system_app.register_instance(self._variables_provider)

def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
Expand All @@ -65,20 +77,13 @@ def on_init(self):

def before_start(self):
"""Called before the start of the application."""
from dbgpt.core.interface.variables import (
FernetEncryption,
StorageVariablesProvider,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer

from .models.models import ServeEntity, VariablesEntity
from .models.variables_adapter import VariablesAdapter

self._db_manager = self.create_or_get_db_manager()
self._serve_config = ServeConfig.from_app_config(
self._system_app.config, SERVE_CONFIG_KEY_PREFIX
)

self._db_manager = self.create_or_get_db_manager()
storage_adapter = VariablesAdapter()
Expand All @@ -89,11 +94,7 @@ def before_start(self):
storage_adapter,
serializer,
)
self._variables_provider = StorageVariablesProvider(
storage=storage,
encryption=FernetEncryption(self._serve_config.encrypt_key),
system_app=self._system_app,
)
self._variables_provider.storage = storage

@property
def variables_provider(self):
Expand Down
Loading

0 comments on commit 2e99116

Please sign in to comment.