Skip to content

Commit

Permalink
Add backwards compatibility for unstructured tensor search (#928)
Browse files Browse the repository at this point in the history
  • Loading branch information
vicilliar committed Aug 6, 2024
1 parent 857b33b commit 00665f1
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/marqo/core/unstructured_vespa_index/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
RANK_PROFILE_EMBEDDING_SIMILARITY_MODIFIERS_2_9 = 'embedding_similarity_modifiers'

# Note field names are also used as query inputs, so make sure these reserved names have a marqo__ prefix
# QUERY_INPUT_EMBEDDING = 'embedding_query'
QUERY_INPUT_EMBEDDING_2_10 = 'embedding_query' # Keep for backwards compatibility
QUERY_INPUT_EMBEDDING = "marqo__query_embedding" # TODO: see if this change from 'embedding_query' to 'marqo__query_embedding' changes anything
QUERY_INPUT_BM25_AGGREGATOR = 'marqo__bm25_aggregator'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ def _to_vespa_tensor_query(self, marqo_query: MarqoTensorQuery) -> Dict[str, Any
else:
ranking = unstructured_common.RANK_PROFILE_EMBEDDING_SIMILARITY

if self._marqo_index_version >= self._HYBRID_SEARCH_MINIMUM_VERSION:
query_input_embedding_parameter = unstructured_common.QUERY_INPUT_EMBEDDING
else:
query_input_embedding_parameter = unstructured_common.QUERY_INPUT_EMBEDDING_2_10

query_inputs = {
unstructured_common.QUERY_INPUT_EMBEDDING: marqo_query.vector_query
query_input_embedding_parameter: marqo_query.vector_query
}

if score_modifiers:
Expand All @@ -102,8 +107,7 @@ def _to_vespa_tensor_query(self, marqo_query: MarqoTensorQuery) -> Dict[str, Any

return query

@staticmethod
def _get_tensor_search_term(marqo_query: MarqoTensorQuery) -> str:
def _get_tensor_search_term(self, marqo_query: MarqoTensorQuery) -> str:
field_to_search = unstructured_common.VESPA_DOC_EMBEDDINGS

if marqo_query.ef_search is not None:
Expand All @@ -113,14 +117,19 @@ def _get_tensor_search_term(marqo_query: MarqoTensorQuery) -> str:
target_hits = marqo_query.limit + marqo_query.offset
additional_hits = 0

if self._marqo_index_version >= self._HYBRID_SEARCH_MINIMUM_VERSION:
query_input_embedding_parameter = unstructured_common.QUERY_INPUT_EMBEDDING
else:
query_input_embedding_parameter = unstructured_common.QUERY_INPUT_EMBEDDING_2_10

return (
f"("
f"{{"
f"targetHits:{target_hits}, "
f"approximate:{str(marqo_query.approximate)}, "
f'hnsw.exploreAdditionalHits:{additional_hits}'
f"}}"
f"nearestNeighbor({field_to_search}, {unstructured_common.QUERY_INPUT_EMBEDDING})"
f"nearestNeighbor({field_to_search}, {query_input_embedding_parameter})"
f")"
)

Expand Down
1 change: 1 addition & 0 deletions src/marqo/core/vespa_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from marqo.core.models.marqo_index import StructuredMarqoIndex, UnstructuredMarqoIndex
from marqo.core.models.score_modifier import ScoreModifier, ScoreModifierType
from marqo.core.models.marqo_index import *
from marqo.exceptions import InternalError


class VespaIndex(ABC):
Expand Down
56 changes: 55 additions & 1 deletion tests/tensor_search/integ_tests/test_search_unstructured.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from marqo.tensor_search.models.search import SearchContext
from tests.marqo_test import MarqoTestCase
from marqo.vespa.models import QueryResult


class TestSearchUnstructured(MarqoTestCase):
Expand Down Expand Up @@ -51,19 +52,25 @@ def setUpClass(cls) -> None:
treat_urls_and_pointers_as_images=True
)

index_with_version_2_10 = cls.unstructured_marqo_index_request(
marqo_version="2.10.0"
)

cls.indexes = cls.create_indexes([
default_text_index,
default_text_index_encoded_name,
default_image_index,
image_index_with_chunking,
image_index_with_random_model
image_index_with_random_model,
index_with_version_2_10
])

cls.default_text_index = default_text_index.name
cls.default_text_index_encoded_name = default_text_index_encoded_name.name
cls.default_image_index = default_image_index.name
cls.image_index_with_chunking = image_index_with_chunking.name
cls.image_index_with_random_model = image_index_with_random_model.name
cls.index_with_version_2_10 = index_with_version_2_10.name

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -1375,3 +1382,50 @@ def test_lexical_query_can_not_be_none(self):
with self.assertRaises(InvalidArgError):
res = tensor_search.search(text=None, config=self.config, index_name=self.default_text_index,
search_method=SearchMethod.LEXICAL)

def test_tensor_search_with_version_below_2_11_query_input_embedding(self):
    """
    Tensor search on an unstructured index of version 2.10 or below must send the
    query vector under the legacy query input name 'embedding_query' rather than
    'marqo__query_embedding'.
    """
    # Canned single-hit Vespa response so the search can complete without a real backend.
    canned_hit = {
        'id': 'index:content_default/0/c81e728d5f3b597225351eac',
        'relevance': 0.39966427718009545,
        'source': 'content_default',
        'fields': {
            'matchfeatures': {
                'closest(marqo__embeddings)': {'type': 'tensor<float>(p{})', 'cells': {'1': 1.0}}
            },
            'sddocname': 'aa4f36de0c4f4433a8c31e4143b28029b',
            'marqo__id': '2',
            'marqo__strings': ['defgh', 'on the mat'],
            'marqo__chunks': ['abc::defgh', 'this_cat_sat::on the mat'],
            'marqo__short_string_fields': {'abc': 'defgh', 'this_cat_sat': 'on the mat'},
        },
    }
    canned_root = {
        'id': 'toplevel',
        'relevance': 1.0,
        'fields': {'totalCount': 2},
        'coverage': {
            'coverage': 100, 'documents': 2, 'full': True, 'nodes': 1, 'results': 1, 'resultsFull': 1
        },
        'children': [canned_hit],
    }

    vespa_query_mock = mock.MagicMock()
    vespa_query_mock.return_value = QueryResult(root=canned_root)

    # Replace the Vespa client's query call for the duration of the search only.
    with mock.patch("marqo.vespa.vespa_client.VespaClient.query", vespa_query_mock):
        tensor_search.search(
            config=self.config,
            index_name=self.index_with_version_2_10,
            text="dogs",
            search_method="TENSOR",
        )

    recorded_calls = vespa_query_mock.call_args_list
    self.assertEqual(len(recorded_calls), 1)

    # Each recorded call is an (args, kwargs) pair; the query is passed via keyword arguments.
    sent_kwargs = recorded_calls[0][1]
    sent_query_features = sent_kwargs["query_features"]

    self.assertIn("nearestNeighbor(marqo__embeddings, embedding_query)",
                  sent_kwargs["yql"])
    self.assertIn("embedding_query", sent_query_features)
    self.assertNotIn("marqo__query_embedding", sent_query_features)

0 comments on commit 00665f1

Please sign in to comment.