Sunday, October 6, 2024

PostgresDB Vector Search - Adding and retrieving


# Embedding model configuration: the sentence-transformers checkpoint to
# load, the device it runs on, and the options passed when encoding.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}

# Embedding function shared by document insertion and similarity search.
embedding_function = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

# Vector store handle. The collection name and connection string are read
# from the environment; a missing variable raises KeyError here.
vector_store = PGVector(
    embeddings=embedding_function,
    collection_name=os.environ["COLLECTION_NAME"],
    connection=os.environ["DB_CONNECTION_STRING"],
    use_jsonb=True,
)



def add_to_db_in_batches(batch_size=100):
    """Add documents from ``data`` that are not yet stored, in batches.

    Each item in ``data`` is expected to expose a ``page_content`` attribute
    holding a JSON object with an ``"id"`` key. Ids are compared and stored as
    strings, so a document is skipped when ``str(id)`` is already among the
    ids returned by ``read_collection_ids()``.

    After every batch a progress line is printed with the number of remaining
    documents and a rough estimate of the remaining time.

    Args:
        batch_size: Maximum number of documents sent to the vector store per
            ``add_documents`` call. Must be a positive integer.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Normalise to strings so the comparison below is type-consistent.
    existing_ids = {str(existing_id) for existing_id in read_collection_ids()}

    # Parse every document's JSON exactly once and keep (id, document) pairs
    # for the documents that are not stored yet. Set membership keeps this
    # linear in the number of documents.
    pending = []
    for document in data:
        document_id = str(json.loads(document.page_content)["id"])
        if document_id not in existing_ids:
            pending.append((document_id, document))

    if not pending:
        return

    total_products = len(pending)
    # Ceiling division: a trailing partial batch still counts as a batch.
    total_batches = (total_products + batch_size - 1) // batch_size
    start_time = time.monotonic()

    batch_starts = range(0, total_products, batch_size)
    for batches_processed, i in enumerate(batch_starts, start=1):
        batch = pending[i:i + batch_size]
        ids = [document_id for document_id, _ in batch]
        batch_data = [document for _, document in batch]
        vector_store.add_documents(batch_data, ids=ids)

        processed = i + len(batch_data)
        remaining = total_products - processed

        # Estimate the remaining time from the average time per batch so far.
        elapsed_time = time.monotonic() - start_time
        average_time_per_batch = elapsed_time / batches_processed
        estimated_remaining_time = average_time_per_batch * (total_batches - batches_processed)
        minutes, seconds = divmod(int(estimated_remaining_time), 60)

        print(f'Added products {i + 1} to {processed} to the database. '
              f'Remaining: {remaining}. Estimated remaining time: {minutes} minutes and {seconds} seconds.')




To search the collection, the following can be used:


import json

from typing import Annotated

from fastapi import Query

from pydantic import BaseModel, Field

from .setup import vector_store



class SearchParams(BaseModel):
    """Validated parameters for a vector-store similarity search."""

    # Required search text. ``max_length`` is the string-length constraint;
    # the previous ``max=150`` keyword did not enforce any limit.
    query: str = Field(..., max_length=150)

    # Number of results to return: defaults to 5, allowed range 5-1000.
    k: int = Field(5, ge=5, le=1000)



def get_search_results(params: Annotated[SearchParams, Query()]):
    """Run a similarity search and return the decoded matches.

    Args:
        params: Search parameters; ``params.query`` is the search text and
            ``params.k`` the number of results requested.

    Returns:
        A list with the JSON-decoded ``page_content`` of each match, in the
        order returned by the vector store.
    """
    matches = vector_store.similarity_search(query=params.query, k=params.k)

    decoded = []
    for match in matches:
        # Each stored document carries its payload as a JSON string.
        decoded.append(json.loads(match.page_content))
    return decoded

No comments:

Post a Comment