Skip to content

Commit

Permalink
Support S3 custom model without explicit credentials (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
farshidz committed Aug 27, 2024
1 parent b43bab5 commit 9210ad2
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 17 deletions.
8 changes: 1 addition & 7 deletions src/marqo/s2_inference/model_downloading/from_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@ def get_presigned_s3_url(location: S3Location, auth: Optional[S3Auth] = None):
TODO: add link to proper usage in error messages
"""
if auth is None:
raise ModelDownloadError(
"Error retrieving private model. s3 authorisation information is required to "
"download a model from an s3 bucket. "
"If the model is publicly accessible, please use the model's publicly accessible URL."
)
s3_client = boto3.client('s3', **auth.dict())
s3_client = boto3.client('s3', **(auth.dict() if auth is not None else {}))
try:
return s3_client.generate_presigned_url('get_object', Params=location.dict(exclude_unset=True))
except NoCredentialsError:
Expand Down
4 changes: 3 additions & 1 deletion src/marqo/s2_inference/processing/custom_clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def download_pretrained_from_s3(
raise ModelDownloadError(
"Received 403 error when trying to retrieve model from s3 storage. "
"Please check the request's s3 credentials and try again. "
)
) from e
else:
raise e

def download_pretrained_from_url(
url: str,
Expand Down
8 changes: 4 additions & 4 deletions src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _get_max_vectorise_batch_size() -> int:
except (ValueError, TypeError) as e:
value_error_msg = f"`{validation_error_msg} Current value: `{max_batch_size_value}`. Reason: {e}"
logger.error(value_error_msg)
raise ConfigurationError(value_error_msg)
raise ConfigurationError(value_error_msg) from e
if batch_size < 1:
batch_size_too_small_msg = f"`{validation_error_msg} Current value: `{max_batch_size_value}`."
logger.error(batch_size_too_small_msg)
Expand Down Expand Up @@ -298,17 +298,17 @@ def _update_available_models(model_cache_key: str, model_name: str, validated_mo
f"Unable to load model={model_name} on device={device} with normalization={normalize_embeddings}. "
f"If you are trying to load a custom model, "
f"please check that model_properties={validated_model_properties} is correct "
f"and Marqo has access to the weights file.")
f"and Marqo has access to the weights file.") from e

else:
most_recently_used_time = datetime.datetime.now()
logger.debug(f'renewed {model_name} on device {device} with new most recently time={most_recently_used_time}.')
try:
_available_models[model_cache_key][AvailableModelsKey.most_recently_used_time] = most_recently_used_time
except KeyError:
except KeyError as e:
raise ModelNotInCacheError(f"Marqo cannot renew model {model_name} on device {device} with normalization={normalize_embeddings}. "
f"Maybe another thread is updating the model cache at the same time."
f"Please wait for 10 seconds and send the request again.\n")
f"Please wait for 10 seconds and send the request again.\n") from e


def validate_model_properties(model_name: str, model_properties: dict) -> dict:
Expand Down
8 changes: 4 additions & 4 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,15 +1664,15 @@ def vectorise_jobs(jobs: List[VectorisedJobs]) -> Dict[JHash, Dict[str, List[flo
except (s2_inference_errors.UnknownModelError,
s2_inference_errors.InvalidModelPropertiesError,
s2_inference_errors.ModelLoadError,
s2_inference.ModelDownloadError) as model_error:
s2_inference.ModelDownloadError) as e:
raise api_exceptions.BadRequestError(
message=f'Problem vectorising query. Reason: {str(model_error)}',
message=f'Problem vectorising query. Reason: {str(e)}',
link=marqo_docs.list_of_models()
)
) from e

except s2_inference_errors.S2InferenceError as e:
# TODO: differentiate image processing errors from other types of vectorise errors
raise api_exceptions.InvalidArgError(message=f'Error vectorising content: {v.content}. Message: {e}')
raise api_exceptions.InvalidArgError(message=f'Error vectorising content: {v.content}. Message: {e}') from e
return result


Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.11.2"
__version__ = "2.11.3"

def get_version() -> str:
return f"{__version__}"

0 comments on commit 9210ad2

Please sign in to comment.