Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move STT and image generation to the assistant UI #42

Merged
merged 18 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions appinfo/routes.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
['name' => 'config#setConfig', 'url' => '/config', 'verb' => 'PUT'],
['name' => 'config#setAdminConfig', 'url' => '/admin-config', 'verb' => 'PUT'],

['name' => 'assistant#getTextProcessingTaskResultPage', 'url' => '/task/view/{taskId}', 'verb' => 'GET'],
['name' => 'assistant#getAssistantTaskResultPage', 'url' => '/task/view/{metaTaskId}', 'verb' => 'GET'],
['name' => 'assistant#getAssistantTask', 'url' => '/task/{metaTaskId}', 'verb' => 'GET'],
['name' => 'assistant#runTextProcessingTask', 'url' => '/task/run', 'verb' => 'POST'],
['name' => 'assistant#scheduleTextProcessingTask', 'url' => '/task/schedule', 'verb' => 'POST'],
['name' => 'assistant#runOrScheduleTextProcessingTask', 'url' => '/task/run-or-schedule', 'verb' => 'POST'],
['name' => 'assistant#getTextProcessingResult', 'url' => '/task/{taskId}', 'verb' => 'GET'],
['name' => 'assistant#parseTextFromFile', 'url' => '/parse-file', 'verb' => 'POST'],

['name' => 'Text2Image#processPrompt', 'url' => '/i/process_prompt', 'verb' => 'POST'],
Expand All @@ -27,7 +27,7 @@
['name' => 'FreePrompt#getOutputs', 'url' => '/f/get_outputs', 'verb' => 'GET'],
['name' => 'FreePrompt#cancelGeneration', 'url' => '/f/cancel_generation', 'verb' => 'POST'],

['name' => 'SpeechToText#getResultPage', 'url' => '/stt/resultPage', 'verb' => 'GET'],
['name' => 'SpeechToText#getResultPage', 'url' => '/stt/result-page/{metaTaskId}', 'verb' => 'GET'],
['name' => 'SpeechToText#transcribeAudio', 'url' => '/stt/transcribeAudio', 'verb' => 'POST'],
['name' => 'SpeechToText#transcribeFile', 'url' => '/stt/transcribeFile', 'verb' => 'POST'],
],
Expand Down
14 changes: 6 additions & 8 deletions lib/Controller/AssistantController.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ public function __construct(
}

/**
* @param int $taskId
* @param int $metaTaskId
* @return TemplateResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
public function getTextProcessingTaskResultPage(int $taskId): TemplateResponse {

public function getAssistantTaskResultPage(int $metaTaskId): TemplateResponse {
if ($this->userId !== null) {
$task = $this->assistantService->getTextProcessingTask($this->userId, $taskId);
$task = $this->assistantService->getAssistantTask($this->userId, $metaTaskId);
if ($task !== null) {
$this->initialStateService->provideInitialState('task', $task->jsonSerializeCc());
return new TemplateResponse(Application::APP_ID, 'taskResultPage');
Expand All @@ -44,14 +43,13 @@ public function getTextProcessingTaskResultPage(int $taskId): TemplateResponse {
}

/**
* @param int $taskId
* @param int $metaTaskId
* @return DataResponse
*/
#[NoAdminRequired]
public function getTextProcessingResult(int $taskId): DataResponse {

public function getAssistantTask(int $metaTaskId): DataResponse {
if ($this->userId !== null) {
$task = $this->assistantService->getTextProcessingTask($this->userId, $taskId);
$task = $this->assistantService->getAssistantTask($this->userId, $metaTaskId);
if ($task !== null) {
return new DataResponse([
'task' => $task->jsonSerializeCc(),
Expand Down
4 changes: 1 addition & 3 deletions lib/Controller/FreePromptController.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
use OCP\AppFramework\Http\Attribute\NoAdminRequired;
use OCP\AppFramework\Http\Attribute\NoCSRFRequired;
use OCP\AppFramework\Http\DataResponse;
use OCP\AppFramework\Services\IInitialState;

use OCP\IL10N;
use OCP\IRequest;
Expand All @@ -23,7 +22,6 @@ public function __construct(
IRequest $request,
private FreePromptService $freePromptService,
private ?string $userId,
private IInitialState $initialStateService,
private IL10N $l10n,
) {
parent::__construct($appName, $request);
Expand All @@ -46,7 +44,7 @@ public function processPrompt(string $prompt): DataResponse {
} catch (Exception $e) {
return new DataResponse(['error' => $e->getMessage()], (int)$e->getCode());
}

return new DataResponse($result);
}

Expand Down
8 changes: 5 additions & 3 deletions lib/Controller/SpeechToTextController.php
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ public function __construct(
}

/**
* @param int $id
* @param int $metaTaskId
* @return TemplateResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
public function getResultPage(int $id): TemplateResponse {
public function getResultPage(int $metaTaskId): TemplateResponse {
$response = new TemplateResponse(Application::APP_ID, 'speechToTextResultPage');
try {
$initData = [
'task' => $this->internalGetTask($id),
'task' => $this->internalGetTask($metaTaskId),
];
} catch (Exception $e) {
$initData = [
Expand Down Expand Up @@ -102,6 +102,7 @@ public function getTranscript(int $id): DataResponse {
*
* @param integer $id
* @return MetaTask
* @throws Exception
*/
private function internalGetTask(int $id): MetaTask {
try {
Expand All @@ -128,6 +129,7 @@ private function internalGetTask(int $id): MetaTask {

/**
* @return DataResponse
* @throws NotPermittedException
*/
#[NoAdminRequired]
public function transcribeAudio(): DataResponse {
Expand Down
29 changes: 20 additions & 9 deletions lib/Controller/Text2ImageController.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use OCP\AppFramework\Services\IInitialState;
use OCP\Db\Exception as DbException;

use OCP\Files\NotPermittedException;
use OCP\IL10N;
use OCP\IRequest;
use OCP\TextToImage\Exception\TaskFailureException;
Expand All @@ -38,17 +39,26 @@ public function __construct(
}

/**
* @param string $appId
* @param string $identifier
* @param string $prompt
* @param int $nResults
* @param bool $displayPrompt
* @param bool $notifyReadyIfScheduled
* @param bool $schedule
* @return DataResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
public function processPrompt(string $prompt, int $nResults = 1, bool $displayPrompt = false): DataResponse {
public function processPrompt(
string $appId, string $identifier, string $prompt, int $nResults = 1, bool $displayPrompt = false,
bool $notifyReadyIfScheduled = false, bool $schedule = false
): DataResponse {
$nResults = min(10, max(1, $nResults));
try {
$result = $this->text2ImageHelperService->processPrompt($prompt, $nResults, $displayPrompt, $this->userId);
$result = $this->text2ImageHelperService->processPrompt(
$appId, $identifier, $prompt, $nResults, $displayPrompt, $this->userId, $notifyReadyIfScheduled, $schedule
);
} catch (Exception | TaskFailureException $e) {
return new DataResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}
Expand Down Expand Up @@ -92,7 +102,7 @@ public function getImage(string $imageGenId, int $fileNameId): DataDisplayRespon
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if ($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
// Throttle brute force attempts
$response->throttle(['action' => 'imageGenId']);
$response->throttle(['imageGenId' => $imageGenId, 'fileId' => $fileNameId, 'status' => $e->getCode()]);
}
return $response;
}
Expand Down Expand Up @@ -125,7 +135,7 @@ public function getGenerationInfo(string $imageGenId): DataResponse {
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if ($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
// Throttle brute force attempts
$response->throttle(['action' => 'imageGenId']);
$response->throttle(['imageGenId' => $imageGenId, 'status' => $e->getCode()]);
}
return $response;
}
Expand All @@ -136,12 +146,12 @@ public function getGenerationInfo(string $imageGenId): DataResponse {
/**
* @param string $imageGenId
* @param array $fileVisStatusArray
* @return DataResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
#[BruteForceProtection(action: 'imageGenId')]
public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStatusArray): DataResponse {

if ($this->userId === null) {
return new DataResponse(['error' => $this->l10n->t('Failed to set visibility of image files; unknown user')], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand All @@ -156,7 +166,7 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
// Throttle brute force attempts
$response->throttle(['action' => 'imageGenId']);
$response->throttle(['imageGenId' => $imageGenId, 'status' => $e->getCode()]);
}
return $response;
}
Expand All @@ -175,7 +185,6 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
#[NoCSRFRequired]
#[AnonRateLimit(limit: 10, period: 60)]
public function notifyWhenReady(string $imageGenId): DataResponse {

if ($this->userId === null) {
return new DataResponse(['error' => $this->l10n->t('Failed to notify when ready; unknown user')], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand All @@ -187,6 +196,7 @@ public function notifyWhenReady(string $imageGenId): DataResponse {
}
return new DataResponse('success', Http::STATUS_OK);
}

/**
* Cancel image generation
*
Expand All @@ -196,12 +206,12 @@ public function notifyWhenReady(string $imageGenId): DataResponse {
*
* @param string $imageGenId
* @return DataResponse
* @throws NotPermittedException
*/
#[NoAdminRequired]
#[NoCSRFRequired]
#[AnonRateLimit(limit: 10, period: 60)]
public function cancelGeneration(string $imageGenId): DataResponse {

if ($this->userId === null) {
return new DataResponse(['error' => $this->l10n->t('Failed to cancel generation; unknown user')], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand All @@ -216,6 +226,7 @@ public function cancelGeneration(string $imageGenId): DataResponse {
* Does not need bruteforce protection
*
* @param string|null $imageGenId
* @param bool|null $forceEditMode
* @return TemplateResponse
*/
#[NoAdminRequired]
Expand All @@ -226,7 +237,7 @@ public function showGenerationPage(?string $imageGenId, ?bool $forceEditMode = f
$forceEditMode = false;
}
$this->initialStateService->provideInitialState('generation-page-inputs', ['image_gen_id' => $imageGenId, 'force_edit_mode' => $forceEditMode]);

return new TemplateResponse(Application::APP_ID, 'imageGenerationPage');
}
}
16 changes: 5 additions & 11 deletions lib/Db/Text2Image/ImageGeneration.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
* @method \void setImageGenId(string $imageGenId)
* @method \string getPrompt()
* @method \void setPrompt(string $prompt)
* @method \void setUserId(string $userId)
* @method \string getUserId()
* @method \void setTimestamp(int $timestamp)
* @method \void setUserId(string $userId)
* @method \int getTimestamp()
* @method \void setExpGenTime(int $expGenTime)
* @method \void setTimestamp(int $timestamp)
* @method \boolean getNotifyReady()
* @method \void setNotifyReady(bool $notifyReady)
* @method \int getExpGenTime()
* @method \void setExpGenTime(int $expGenTime)
*
*/
class ImageGeneration extends Entity implements \JsonSerializable {
Expand Down Expand Up @@ -80,12 +82,4 @@ public function setFailed(?bool $failed): void {
public function getFailed(): bool {
return $this->failed === true;
}

public function setNotifyReady(?bool $notifyReady): void {
$this->notifyReady = $notifyReady === true;
}

public function getNotifyReady(): bool {
return $this->notifyReady === true;
}
}
8 changes: 6 additions & 2 deletions lib/Db/Text2Image/ImageGenerationMapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,22 @@ public function getImageGenerationOfImageGenId(string $imageGenId): ImageGenerat
* @param string $prompt
* @param string $userId
* @param int|null $expCompletionTime
* @param bool $notifyReady
* @return ImageGeneration
* @throws Exception
*/
public function createImageGeneration(string $imageGenId, string $prompt = '', string $userId = '', ?int $expCompletionTime = null): ImageGeneration {
public function createImageGeneration(
string $imageGenId, string $prompt = '', string $userId = '', ?int $expCompletionTime = null,
bool $notifyReady = false
): ImageGeneration {
$imageGeneration = new ImageGeneration();
$imageGeneration->setImageGenId($imageGenId);
$imageGeneration->setTimestamp((new DateTime())->getTimestamp());
$imageGeneration->setPrompt($prompt);
$imageGeneration->setUserId($userId);
$imageGeneration->setIsGenerated(false);
$imageGeneration->setFailed(false);
$imageGeneration->setNotifyReady(false);
$imageGeneration->setNotifyReady($notifyReady);
$imageGeneration->setExpGenTime($expCompletionTime ?? (new DateTime())->getTimestamp());
return $this->insert($imageGeneration);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Notification/Notifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public function prepare(INotification $notification, string $languageCode): INot
}


$link = $params['target'] ?? $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getTextProcessingTaskResultPage', ['taskId' => $params['id']]);
$link = $params['target'] ?? $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getAssistantTaskResultPage', ['metaTaskId' => $params['id']]);
$iconUrl = $this->url->getAbsoluteURL($this->url->imagePath('core', 'actions/error.svg'));

$notification
Expand Down
12 changes: 6 additions & 6 deletions lib/Reference/Text2ImageReferenceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

use Exception;
use OCA\TpAssistant\AppInfo\Application;
use OCA\TpAssistant\Db\Text2Image\ImageGeneration;
use OCA\TpAssistant\Db\Text2Image\ImageGenerationMapper;
use OCP\Collaboration\Reference\ADiscoverableReferenceProvider;
use OCP\Collaboration\Reference\IReference;
Expand Down Expand Up @@ -74,7 +73,6 @@ public function resolveReference(string $referenceText): ?IReference {
}

try {
/** @var ImageGeneration $imageGeneration */
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($imageGenId);
} catch (Exception $e) {
$imageGeneration = null;
Expand All @@ -89,14 +87,16 @@ public function resolveReference(string $referenceText): ?IReference {
$reference = new Reference($referenceText);
$imageUrl = $this->urlGenerator->linkToRouteAbsolute(
Application::APP_ID . '.Text2Image.getGenerationInfo',
[
'imageGenId' => $imageGenId,
]
['imageGenId' => $imageGenId]
);

$reference->setImageUrl($imageUrl);

$richObjectInfo = ['prompt' => $prompt, 'proxied_url' => $imageUrl];
$richObjectInfo = [
'prompt' => $prompt,
'proxied_url' => $imageUrl,
'imageGenId' => $imageGenId,
];
$reference->setRichObject(
self::RICH_OBJECT_TYPE,
$richObjectInfo,
Expand Down
28 changes: 9 additions & 19 deletions lib/Service/AssistantService.php
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,7 @@ public function sendNotification(MetaTask $task, ?string $customTarget = null, ?
}

private function getDefaultTarget(MetaTask $task): string {
$category = $task->getCategory();
if ($category === Application::TASK_CATEGORY_TEXT_GEN) {
return $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getTextProcessingTaskResultPage', ['taskId' => $task->getId()]);
} elseif ($category === Application::TASK_CATEGORY_SPEECH_TO_TEXT) {
return $this->url->linkToRouteAbsolute(Application::APP_ID . '.SpeechToText.getResultPage', ['id' => $task->getId()]);
} elseif ($category === Application::TASK_CATEGORY_TEXT_TO_IMAGE) {
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($task->getIdentifier());
return $this->url->linkToRouteAbsolute(
Application::APP_ID . '.Text2Image.showGenerationPage',
[
'imageGenId' => $imageGeneration->getImageGenId(),
]
);
}
return '';
return $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getAssistantTaskResultPage', ['metaTaskId' => $task->getId()]);
}

/**
Expand Down Expand Up @@ -168,19 +154,23 @@ private function sanitizeInputs(string $type, array $inputs): array {

/**
* @param string $userId
* @param int $taskId
* @param int $metaTaskId
* @return MetaTask|null
*/
public function getTextProcessingTask(string $userId, int $taskId): ?MetaTask {
public function getAssistantTask(string $userId, int $metaTaskId): ?MetaTask {
try {
$metaTask = $this->metaTaskMapper->getMetaTask($taskId);
$metaTask = $this->metaTaskMapper->getMetaTask($metaTaskId);
} catch (DoesNotExistException | MultipleObjectsReturnedException | \OCP\Db\Exception $e) {
return null;
}
if ($metaTask->getUserId() !== $userId) {
return null;
}
// Check if the task status is up-to-date (if not, update status and output)
// only try to update meta task status for text processing ones
if ($metaTask->getCategory() !== Application::TASK_CATEGORY_TEXT_GEN) {
return $metaTask;
}
// Check if the text processing task status is up-to-date (if not, update status and output)
try {
$ocpTask = $this->textProcessingManager->getTask($metaTask->getOcpTaskId());

Expand Down
Loading
Loading