from sentence_transformers import CrossEncoder
# Build the cross-encoder a single time at import so that every call to
# search_with_reranking() reuses the same instance instead of constructing
# a new one per request.
# NOTE(review): constructing it presumably loads (and may download) the
# model weights, making module import slow -- confirm this is acceptable.
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
async def search_with_reranking(
    query: str,
    limit: int = 5,
    *,
    candidate_multiplier: int = 4,
    max_candidates: int = 20,
) -> list:
    """Return the ``limit`` best chunks for ``query`` using two-stage retrieval.

    Stage 1 fetches an over-sized candidate set from the ``chunks`` table
    ordered by vector distance to the query embedding. Stage 2 scores each
    ``(query, content)`` pair with the module-level cross-encoder
    ``reranker`` and keeps the highest-scoring rows.

    Args:
        query: The user's search text.
        limit: Maximum number of rows to return. Values <= 0 yield ``[]``.
        candidate_multiplier: How many candidates to fetch per requested
            result before re-ranking.
        max_candidates: Soft cap on the candidate pool. It is never allowed
            to push the pool below ``limit``, so a large ``limit`` is still
            honoured.

    Returns:
        The rows returned by ``db.query`` (at most ``limit`` of them),
        ordered from highest to lowest re-ranker score.
    """
    # Nothing to return; skip the embedding, DB and model work entirely.
    if limit <= 0:
        return []

    # Local import keeps this function self-contained without touching the
    # file's import block.
    import asyncio

    # Stage 1: fast vector retrieval. Over-fetch so the re-ranker has room to
    # reorder, but never fetch fewer rows than the caller asked for.
    candidate_limit = max(limit, min(limit * candidate_multiplier, max_candidates))
    query_embedding = await embedder.embed_query(query)
    # The original query had no operator between `embedding` and `$1`, which
    # is not valid SQL. `<=>` is pgvector's cosine-distance operator.
    # NOTE(review): confirm it matches the operator class of the index on
    # `chunks.embedding` (`<->` for L2, `<#>` for inner product).
    candidates = await db.query(
        "SELECT content, metadata FROM chunks ORDER BY embedding <=> $1 LIMIT $2",
        query_embedding,
        candidate_limit,
    )
    if not candidates:
        return []

    # Stage 2: re-rank with the cross-encoder. predict() is synchronous and
    # compute-heavy, so run it in the default executor rather than stalling
    # the event loop for every other coroutine.
    pairs = [[query, row["content"]] for row in candidates]
    loop = asyncio.get_running_loop()
    scores = await loop.run_in_executor(None, reranker.predict, pairs)

    # Highest score first; sorted() is stable, so ties keep retrieval order.
    ranked = sorted(
        zip(candidates, scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [doc for doc, _score in ranked[:limit]]
# (Footer text from the original web page -- "No comments: / Post a Comment" -- is not part of the code.)