-
Notifications
You must be signed in to change notification settings - Fork 4
/
_52_CassandraVectorDbTest.java
94 lines (80 loc) · 4.57 KB
/
_52_CassandraVectorDbTest.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package devoxx.demo._5_vectorsearch;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import devoxx.demo.devoxx.Product;
import devoxx.demo.utils.AbstractDevoxxTestSupport;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
/**
 * Demo test that walks through a vector search against Cassandra.
 *
 * <p>It creates a table with a 14-dimensional float vector column, indexes that column,
 * inserts six sample pet-supply products, reads one of them back by primary key and finally
 * asks for the two rows whose vectors are closest to it.
 *
 * <p>The test carries {@code @Disabled}, so JUnit skips it unless the annotation is removed.
 * The session comes from the inherited {@code getCassandraSession()} helper.
 */
public class _52_CassandraVectorDbTest extends AbstractDevoxxTestSupport {

    /** One logger per class is enough; it is never reassigned. */
    private static final Logger log = LoggerFactory.getLogger(_52_CassandraVectorDbTest.class);

    /** Column list shared by every sample-row INSERT; a VALUES clause is appended per row. */
    private static final String INSERT_PREFIX =
            "INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) ";

    /** VALUES clauses of the sample rows, in insertion order (id, name, 14-element vector). */
    private static final List<String> SAMPLE_ROW_VALUES = List.of(
            "VALUES ('pf1843','HealthyFresh - Chicken raw dog food',[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])",
            "VALUES ('pf1844','HealthyFresh - Beef raw dog food',[1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])",
            "VALUES ('pt0021','Dog Tennis Ball Toy',[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0])",
            "VALUES ('pt0041','Dog Ring Chew Toy',[0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0])",
            "VALUES ('pf7043','PupperSausage Bacon dog Treats',[0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1])",
            "VALUES ('pf7044','PupperSausage Beef dog Treats',[0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0])");

    /**
     * Runs the whole scenario inside a single session: schema setup, data load, lookup of
     * product {@code pf1843} by id, then a similarity query driven by that product's vector.
     *
     * @throws IOException declared by the original signature and kept for compatibility
     * @throws RuntimeException if the reference product cannot be read back
     */
    @Test
    @Disabled
    public void cassandraVectorSearchTest() throws IOException {
        // try-with-resources closes the session even when a statement fails.
        try (CqlSession cqlSession = getCassandraSession()) {
            log.info("Connected to Cassandra");

            createTable(cqlSession);
            log.info("Table created.");

            createVectorIndex(cqlSession);
            log.info("Index Created.");

            insertSampleRows(cqlSession);

            Product p = findProductById(cqlSession, "pf1843");
            log.info("Product Found ! looking for similar products");

            List<Product> similarProducts = findSimilarProducts(cqlSession, p);
            log.info("Similar Products : {}", similarProducts);
        }
    }

    /** Creates the {@code pet_supply_vectors} table if it does not exist yet. */
    private void createTable(CqlSession cqlSession) {
        cqlSession.execute(
                "CREATE TABLE IF NOT EXISTS pet_supply_vectors (" +
                " product_id TEXT PRIMARY KEY," +
                " product_name TEXT," +
                " product_vector vector<float, 14>)");
    }

    /** Creates the storage-attached index on the vector column if it does not exist yet. */
    private void createVectorIndex(CqlSession cqlSession) {
        cqlSession.execute(
                "CREATE CUSTOM INDEX IF NOT EXISTS idx_vector " +
                "ON pet_supply_vectors(product_vector) " +
                "USING 'StorageAttachedIndex'");
    }

    /** Inserts every sample row, one statement per row, in declaration order. */
    private void insertSampleRows(CqlSession cqlSession) {
        for (String values : SAMPLE_ROW_VALUES) {
            cqlSession.execute(INSERT_PREFIX + values);
        }
    }

    /**
     * Reads a single product by primary key.
     *
     * @param cqlSession open session to query
     * @param productId  value bound to the {@code product_id} placeholder
     * @return the mapped product
     * @throws RuntimeException if the query yields no row
     */
    private Product findProductById(CqlSession cqlSession, String productId) {
        SimpleStatement byId = SimpleStatement
                .builder("SELECT * FROM pet_supply_vectors WHERE product_id = ?")
                .addPositionalValue(productId)
                .build();
        // one() yields null on an empty result, hence the Optional wrapper.
        Row row = cqlSession.execute(byId).one();
        return Optional.ofNullable(row)
                .map(this::mapCassandraRow2Product)
                .orElseThrow(() -> new RuntimeException("Product not found"));
    }

    /**
     * Returns the two rows ranked first by {@code ANN OF} against the reference vector.
     *
     * <p>NOTE(review): the reference product is stored in the same table, so it is presumably
     * part of the result itself - confirm if only <em>other</em> products are wanted.
     *
     * @param cqlSession open session to query
     * @param reference  product whose {@code vector()} is bound as the search vector
     * @return the mapped rows, in the order the query returned them
     */
    private List<Product> findSimilarProducts(CqlSession cqlSession, Product reference) {
        SimpleStatement annQuery = SimpleStatement
                .builder("SELECT * FROM pet_supply_vectors " +
                        "ORDER BY product_vector ANN OF ? LIMIT 2;")
                .addPositionalValue(reference.vector())
                .build();
        ResultSet resultSet = cqlSession.execute(annQuery);
        return resultSet.all()
                .stream()
                .map(this::mapCassandraRow2Product)
                .toList();
    }

    /**
     * Builds a {@link Product} from the {@code product_id}, {@code product_name} and
     * {@code product_vector} columns of a row; the vector is passed through as read by
     * {@code getObject}, without conversion.
     */
    private Product mapCassandraRow2Product(Row row) {
        return new Product(
                row.getString("product_id"),
                row.getString("product_name"),
                row.getObject("product_vector"));
    }
}