Skip to content

Commit

Permalink
Don't rely on injected userId in services
Browse files Browse the repository at this point in the history
Signed-off-by: MB-Finski <[email protected]>
  • Loading branch information
MB-Finski committed Jan 30, 2024
1 parent 292bbb0 commit dd2b2f1
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 36 deletions.
2 changes: 1 addition & 1 deletion lib/Controller/AssistantController.php
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public function runOrScheduleTextProcessingTask(string $type, array $inputs, str
#[NoAdminRequired]
public function parseTextFromFile(string $filePath): DataResponse {
try {
$text = $this->assistantService->parseTextFromFile($filePath);
$text = $this->assistantService->parseTextFromFile($filePath, $this->userId);
} catch (\Exception | \Throwable $e) {
return new DataResponse($e->getMessage(), Http::STATUS_BAD_REQUEST);
}
Expand Down
20 changes: 8 additions & 12 deletions lib/Controller/Text2ImageController.php
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public function __construct(
public function processPrompt(string $prompt, int $nResults = 1, bool $displayPrompt = false): DataResponse {
$nResults = min(10, max(1, $nResults));
try {
$result = $this->text2ImageHelperService->processPrompt($prompt, $nResults, $displayPrompt);
$result = $this->text2ImageHelperService->processPrompt($prompt, $nResults, $displayPrompt, $this->userId);
} catch (Exception | TaskFailureException $e) {
return new DataResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}
Expand All @@ -61,7 +61,7 @@ public function processPrompt(string $prompt, int $nResults = 1, bool $displayPr
#[NoCSRFRequired]
public function getPromptHistory(): DataResponse {
try {
$response = $this->text2ImageHelperService->getPromptHistory();
$response = $this->text2ImageHelperService->getPromptHistory($this->userId);
} catch (DbException $e) {
return new DataResponse(['error' => 'Unknown error while retrieving prompt history.'], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand Down Expand Up @@ -113,7 +113,7 @@ public function getImage(string $imageGenId, int $fileNameId): DataDisplayRespon
#[BruteForceProtection(action: 'imageGenId')]
public function getGenerationInfo(string $imageGenId): DataResponse {
try {
$result = $this->text2ImageHelperService->getGenerationInfo($imageGenId, true);
$result = $this->text2ImageHelperService->getGenerationInfo($imageGenId, $this->userId, true);
} catch (Exception $e) {
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if ($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
Expand All @@ -139,7 +139,7 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
}

try {
$this->text2ImageHelperService->setVisibilityOfImageFiles($imageGenId, $fileVisStatusArray);
$this->text2ImageHelperService->setVisibilityOfImageFiles($imageGenId, $fileVisStatusArray, $this->userId);
} catch (Exception $e) {
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
Expand All @@ -164,7 +164,7 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
#[AnonRateLimit(limit: 10, period: 60)]
public function notifyWhenReady(string $imageGenId): DataResponse {
try {
$this->text2ImageHelperService->notifyWhenReady($imageGenId);
$this->text2ImageHelperService->notifyWhenReady($imageGenId, $this->userId);
} catch (Exception $e) {
// Ignore
}
Expand All @@ -184,7 +184,7 @@ public function notifyWhenReady(string $imageGenId): DataResponse {
#[NoCSRFRequired]
#[AnonRateLimit(limit: 10, period: 60)]
public function cancelGeneration(string $imageGenId): DataResponse {
$this->text2ImageHelperService->cancelGeneration($imageGenId);
$this->text2ImageHelperService->cancelGeneration($imageGenId, $this->userId);
return new DataResponse('success', Http::STATUS_OK);
}

Expand All @@ -203,12 +203,8 @@ public function showGenerationPage(?string $imageGenId, ?bool $forceEditMode = f
if ($forceEditMode === null) {
$forceEditMode = false;
}
if ($imageGenId === null) {
$this->initialStateService->provideInitialState('generation-page-inputs', ['image_gen_id' => $imageGenId, 'force_edit_mode' => $forceEditMode]);
} else {
$this->initialStateService->provideInitialState('generation-page-inputs', ['image_gen_id' => $imageGenId, 'force_edit_mode' => $forceEditMode]);
}

$this->initialStateService->provideInitialState('generation-page-inputs', ['image_gen_id' => $imageGenId, 'force_edit_mode' => $forceEditMode]);

return new TemplateResponse(Application::APP_ID, 'imageGenerationPage');
}
}
6 changes: 3 additions & 3 deletions lib/Service/AssistantService.php
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ public function __construct(
private TaskMapper $taskMapper,
private LoggerInterface $logger,
private IRootFolder $storage,
private ?string $userId,
) {
}

Expand Down Expand Up @@ -286,12 +285,13 @@ public function runOrScheduleTextProcessingTask(string $type, array $inputs, str
/**
* Parse text from file (if parsing the file type is supported)
* @param string $filePath
* @param string $userId
* @return string
* @throws \Exception
*/
public function parseTextFromFile(string $filePath): string {
public function parseTextFromFile(string $filePath, string $userId): string {
try {
$userFolder = $this->storage->getUserFolder($this->userId);
$userFolder = $this->storage->getUserFolder($userId);
} catch (\OC\User\NoUserException | NotPermittedException $e) {
throw new \Exception('Could not access user storage.');
}
Expand Down
45 changes: 25 additions & 20 deletions lib/Service/Text2Image/Text2ImageHelperService.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ public function __construct(
private IL10N $l10n,
private AssistantService $assistantService,
private TaskMapper $taskMapper,
private ?string $userId,
) {
}

Expand All @@ -64,13 +63,14 @@ public function __construct(
* @param string $prompt
* @param int $nResults
* @param bool $displayPrompt
* @param string $userId
* @return array
* @throws Exception
* @throws PreConditionNotMetException
* @throws TaskFailureException ;
* @throws RandomException
*/
public function processPrompt(string $prompt, int $nResults, bool $displayPrompt): array {
public function processPrompt(string $prompt, int $nResults, bool $displayPrompt, string $userId): array {
if (!$this->textToImageManager->hasProviders()) {
$this->logger->error('No text to image processing provider available');
throw new BaseException($this->l10n->t('No text to image processing provider available'));
Expand All @@ -83,7 +83,7 @@ public function processPrompt(string $prompt, int $nResults, bool $displayPrompt
$imageGenId = bin2hex(random_bytes(16));
}

$promptTask = new Task($prompt, Application::APP_ID, $nResults, $this->userId, $imageGenId);
$promptTask = new Task($prompt, Application::APP_ID, $nResults, $userId, $imageGenId);

$this->textToImageManager->runOrScheduleTask($promptTask);

Expand All @@ -101,12 +101,12 @@ public function processPrompt(string $prompt, int $nResults, bool $displayPrompt
}

// Store the image id to the db:
$this->imageGenerationMapper->createImageGeneration($imageGenId, $displayPrompt ? $prompt : '', $this->userId ?? '', $expCompletionTime->getTimestamp());
$this->imageGenerationMapper->createImageGeneration($imageGenId, $displayPrompt ? $prompt : '', $userId, $expCompletionTime->getTimestamp());

// Create an assistant meta task for the image generation task:
// TODO check if we should create a task if userId is null
$this->taskMapper->createTask(
$this->userId ?? '',
$userId,
['prompt' => $prompt],
$imageGenId,
time(),
Expand Down Expand Up @@ -137,8 +137,8 @@ public function processPrompt(string $prompt, int $nResults, bool $displayPrompt
);

// Save the prompt to database
if ($this->userId !== null) {
$this->promptMapper->createPrompt($this->userId, $prompt);
if ($userId !== null) {
$this->promptMapper->createPrompt($userId, $prompt);
}

return ['url' => $infoUrl, 'reference_url' => $referenceUrl, 'image_gen_id' => $imageGenId, 'prompt' => $prompt];
Expand Down Expand Up @@ -172,14 +172,15 @@ private function genIdExists(string $imageGenId): bool {
}

/**
* @param string $userId
* @return array
* @throws \OCP\DB\Exception
*/
public function getPromptHistory(): array {
if ($this->userId === null) {
public function getPromptHistory(string $userId): array {
if ($userId === null) {
return [];
} else {
return $this->promptMapper->getPromptsOfUser($this->userId);
return $this->promptMapper->getPromptsOfUser($userId);
}
}

Expand Down Expand Up @@ -286,11 +287,12 @@ public function getImageDataFolder(): ISimpleFolder {
* Get image generation info
*
* @param string $imageGenId
* @param string $userId
* @param bool $updateTimestamp
* @return array
* @throws \Exception
*/
public function getGenerationInfo(string $imageGenId, bool $updateTimestamp = true): array {
public function getGenerationInfo(string $imageGenId, string $userId, bool $updateTimestamp = true): array {
// Check whether the task has completed:
try {
/** @var ImageGeneration $imageGeneration */
Expand All @@ -311,7 +313,7 @@ public function getGenerationInfo(string $imageGenId, bool $updateTimestamp = tr
throw new BaseException($this->l10n->t('Retrieving the image generation failed.'), Http::STATUS_INTERNAL_SERVER_ERROR);
}

$isOwner = ($imageGeneration->getUserId() === $this->userId);
$isOwner = ($imageGeneration->getUserId() === $userId);

if ($imageGeneration->getFailed() === true) {
throw new BaseException($this->l10n->t('Image generation failed.'), Http::STATUS_INTERNAL_SERVER_ERROR);
Expand Down Expand Up @@ -405,10 +407,11 @@ public function getImage(string $imageGenId, int $imageFileNameId): array {
* Cancel image generation
*
* @param string $imageGenId
* @param string $userId
* @return void
* @throws NotPermittedException
*/
public function cancelGeneration(string $imageGenId): void {
public function cancelGeneration(string $imageGenId, string $userId): void {
try {
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($imageGenId);
} catch (Exception | DoesNotExistException | MultipleObjectsReturnedException $e) {
Expand All @@ -418,14 +421,14 @@ public function cancelGeneration(string $imageGenId): void {

if ($imageGeneration) {
// Make sure the user is associated with the image generation
if ($imageGeneration->getUserId() !== $this->userId) {
if ($imageGeneration->getUserId() !== $userId) {
$this->logger->warning('User attempted deleting another user\'s image generation!', ['app' => Application::APP_ID]);
return;
}

// Get the generation task if it exists
try {
$task = $this->textToImageManager->getUserTasksByApp($this->userId, Application::APP_ID, $imageGenId);
$task = $this->textToImageManager->getUserTasksByApp($userId, Application::APP_ID, $imageGenId);
} catch (RuntimeException $e) {
$this->logger->debug('Task cancellation failed or it does not exist: ' . $e->getMessage(), ['app' => Application::APP_ID]);
$task = [];
Expand Down Expand Up @@ -484,10 +487,11 @@ public function cancelGeneration(string $imageGenId): void {
*
* @param string $imageGenId
* @param array $fileVisStatusArray
* @param string $userId
* @return void
* @throws BaseException
*/
public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStatusArray): void {
public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStatusArray, string $userId): void {
try {
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($imageGenId);
} catch (DoesNotExistException $e) {
Expand All @@ -498,7 +502,7 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
throw new BaseException('Internal server error.', Http::STATUS_INTERNAL_SERVER_ERROR);
}

if ($imageGeneration->getUserId() !== $this->userId) {
if ($imageGeneration->getUserId() !== $userId) {
$this->logger->warning('User attempted deleting another user\'s image generation!');
throw new BaseException('Unauthorized.', Http::STATUS_UNAUTHORIZED);
}
Expand All @@ -517,9 +521,10 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
* Notify when image generation is ready
*
* @param string $imageGenId
* @param string $userId
* @throws Exception
*/
public function notifyWhenReady(string $imageGenId): void {
public function notifyWhenReady(string $imageGenId, string $userId): void {
try {
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($imageGenId);
} catch (DoesNotExistException $e) {
Expand All @@ -539,7 +544,7 @@ public function notifyWhenReady(string $imageGenId): void {
throw new BaseException('Internal server error.', Http::STATUS_INTERNAL_SERVER_ERROR);
}

if ($imageGeneration->getUserId() !== $this->userId) {
if ($imageGeneration->getUserId() !== $userId) {
$this->logger->warning('User attempted enabling notifications of another user\'s image generation!');
throw new BaseException('Unauthorized.', Http::STATUS_UNAUTHORIZED);
}
Expand All @@ -548,7 +553,7 @@ public function notifyWhenReady(string $imageGenId): void {

// Just in case check if the image generation is already ready and, if so, notify the user immediately so that the result is not lost:
try {
$tasks = $this->textToImageManager->getUserTasksByApp($this->userId, Application::APP_ID, $imageGenId);
$tasks = $this->textToImageManager->getUserTasksByApp($userId, Application::APP_ID, $imageGenId);
} catch (RuntimeException $e) {
$this->logger->debug('Assistant meta task for the given generation id does not exist or could not be retrieved: ' . $e->getMessage(), ['app' => Application::APP_ID]);
return;
Expand Down

0 comments on commit dd2b2f1

Please sign in to comment.