diff --git a/common/constants/llm.ts b/common/constants/llm.ts index 48da38fa..c6988d28 100644 --- a/common/constants/llm.ts +++ b/common/constants/llm.ts @@ -44,5 +44,6 @@ export const DEFAULT_USER_NAME = 'User'; export const TEXT2VEGA_INPUT_SIZE_LIMIT = 400; -export const TEXT2VEGA_AGENT_CONFIG_ID = 'os_text2vega'; +export const TEXT2VEGA_RULE_BASED_AGENT_CONFIG_ID = 'os_text2vega'; +export const TEXT2VEGA_WITH_INSTRUCTIONS_AGENT_CONFIG_ID = 'os_text2vega_with_instructions'; export const TEXT2PPL_AGENT_CONFIG_ID = 'os_query_assist_ppl'; diff --git a/public/components/visualization/source_selector.tsx b/public/components/visualization/source_selector.tsx index f053a074..90438a4e 100644 --- a/public/components/visualization/source_selector.tsx +++ b/public/components/visualization/source_selector.tsx @@ -14,7 +14,11 @@ import { DataSourceOption, } from '../../../../../src/plugins/data/public'; import { StartServices } from '../../types'; -import { TEXT2VEGA_AGENT_CONFIG_ID } from '../../../common/constants/llm'; +import { + TEXT2PPL_AGENT_CONFIG_ID, + TEXT2VEGA_RULE_BASED_AGENT_CONFIG_ID, + TEXT2VEGA_WITH_INSTRUCTIONS_AGENT_CONFIG_ID, +} from '../../../common/constants/llm'; import { getAssistantService } from '../../services'; const DEFAULT_DATA_SOURCE_TYPE = 'DEFAULT_INDEX_PATTERNS'; @@ -110,13 +114,20 @@ export const SourceSelector = ({ const assistantService = getAssistantService(); /** - * Check each data source to see if text to vega agent is configured or not + * Check each data source to see if text to vega agents are configured or not * If not configured, disable the corresponding index pattern from the selection list */ Object.keys(dataSourceIdToIndexPatternIds).forEach(async (key) => { - const res = await assistantService.client.agentConfigExists(TEXT2VEGA_AGENT_CONFIG_ID, { - dataSourceId: key !== 'DEFAULT' ? key : undefined, - }); + const res = await assistantService.client.agentConfigExists( + [ + TEXT2VEGA_RULE_BASED_AGENT_CONFIG_ID, + TEXT2VEGA_WITH_INSTRUCTIONS_AGENT_CONFIG_ID, + TEXT2PPL_AGENT_CONFIG_ID, + ], + { + dataSourceId: key !== 'DEFAULT' ? key : undefined, + } + ); if (!res.exists) { dataSourceIdToIndexPatternIds[key].forEach((indexPatternId) => { indexPatternOptions.options.forEach((option) => { diff --git a/public/services/assistant_client.ts b/public/services/assistant_client.ts index 624b86d8..1d151a9b 100644 --- a/public/services/assistant_client.ts +++ b/public/services/assistant_client.ts @@ -38,9 +38,10 @@ export class AssistantClient { }; /** - * Return if the given agent config name has agent id configured + * Check if the given agent config name has agent id configured + * Return false if any of the given config name has no agent id configured */ - agentConfigExists = (agentConfigName: string, options?: Options) => { + agentConfigExists = (agentConfigName: string | string[], options?: Options) => { return this.http.fetch<{ exists: boolean }>({ method: 'GET', path: AGENT_API.CONFIG_EXISTS, diff --git a/server/routes/agent_routes.ts b/server/routes/agent_routes.ts index abbbd087..227143a7 100644 --- a/server/routes/agent_routes.ts +++ b/server/routes/agent_routes.ts @@ -51,7 +51,7 @@ export function registerAgentRoutes(router: IRouter, assistantService: Assistant query: schema.oneOf([ schema.object({ dataSourceId: schema.maybe(schema.string()), - agentConfigName: schema.string(), + agentConfigName: schema.oneOf([schema.string(), schema.arrayOf(schema.string())]), }), ]), }, @@ -59,8 +59,12 @@ export function registerAgentRoutes(router: IRouter, assistantService: Assistant router.handleLegacyErrors(async (context, req, res) => { try { const assistantClient = assistantService.getScopedClient(req, context); - const agentId = await assistantClient.getAgentIdByConfigName(req.query.agentConfigName); - return res.ok({ body: { exists: Boolean(agentId) } }); + const promises = Array() + .concat(req.query.agentConfigName) + .map((configName) => assistantClient.getAgentIdByConfigName(configName)); + const results = await Promise.all(promises); + const exists = results.every((r) => Boolean(r)); + return res.ok({ body: { exists } }); } catch (e) { return res.ok({ body: { exists: false } }); } diff --git a/server/routes/text2viz_routes.ts b/server/routes/text2viz_routes.ts index bf6b9cf9..308808c5 100644 --- a/server/routes/text2viz_routes.ts +++ b/server/routes/text2viz_routes.ts @@ -7,8 +7,9 @@ import { schema } from '@osd/config-schema'; import { IRouter } from '../../../../src/core/server'; import { TEXT2PPL_AGENT_CONFIG_ID, - TEXT2VEGA_AGENT_CONFIG_ID, + TEXT2VEGA_RULE_BASED_AGENT_CONFIG_ID, TEXT2VEGA_INPUT_SIZE_LIMIT, + TEXT2VEGA_WITH_INSTRUCTIONS_AGENT_CONFIG_ID, TEXT2VIZ_API, } from '../../common/constants/llm'; import { AssistantServiceSetup } from '../services/assistant_service'; @@ -42,7 +43,10 @@ export function registerText2VizRoutes(router: IRouter, assistantService: Assist router.handleLegacyErrors(async (context, req, res) => { const assistantClient = assistantService.getScopedClient(req, context); try { - const response = await assistantClient.executeAgentByConfigName(TEXT2VEGA_AGENT_CONFIG_ID, { + const agentConfigName = req.body.input_instruction + ? TEXT2VEGA_WITH_INSTRUCTIONS_AGENT_CONFIG_ID + : TEXT2VEGA_RULE_BASED_AGENT_CONFIG_ID; + const response = await assistantClient.executeAgentByConfigName(agentConfigName, { input_question: req.body.input_question, input_instruction: req.body.input_instruction, ppl: req.body.ppl,