Wednesday, June 18, 2025

Detail on GraphRAGExtractor

The GraphRAGExtractor class is designed to extract triples (subject-relation-object) from text and enrich them by adding descriptions for entities and relationships to their properties using an LLM.

This functionality is similar to that of the SimpleLLMPathExtractor, but includes additional enhancements to handle entity and relationship descriptions. For guidance on implementation, you may look at similar existing extractors.

Here's a breakdown of its functionality:

Key Components:

llm: The language model used for extraction.

extract_prompt: A prompt template used to guide the LLM in extracting information.

parse_fn: A function to parse the LLM's output into structured data.

max_paths_per_chunk: Limits the number of triples extracted per text chunk.

num_workers: For parallel processing of multiple text nodes.

Main Methods:

__call__: The entry point for processing a list of text nodes.

acall: An asynchronous version of `__call__` for improved performance.

_aextract: The core method that processes each individual node.

Extraction Process:

For each input node (chunk of text):

It sends the text to the LLM along with the extraction prompt.

The LLM's response is parsed to extract entities, relationships, and descriptions for both.

Entities are converted into EntityNode objects; each entity's description is stored in its metadata.

Relationships are converted into Relation objects. Relationship description is stored in metadata.

These are added to the node's metadata under KG_NODES_KEY and KG_RELATIONS_KEY.

NOTE: In the current implementation, we are using only relationship descriptions. In the next implementation, we will utilize entity descriptions during the retrieval stage.


import asyncio

import nest_asyncio


nest_asyncio.apply()


from typing import Any, List, Callable, Optional, Union, Dict

from IPython.display import Markdown, display


from llama_index.core.async_utils import run_jobs

from llama_index.core.indices.property_graph.utils import (

    default_parse_triplets_fn,

)

from llama_index.core.graph_stores.types import (

    EntityNode,

    KG_NODES_KEY,

    KG_RELATIONS_KEY,

    Relation,

)

from llama_index.core.llms.llm import LLM

from llama_index.core.prompts import PromptTemplate

from llama_index.core.prompts.default_prompts import (

    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,

)

from llama_index.core.schema import TransformComponent, BaseNode

from llama_index.core.bridge.pydantic import BaseModel, Field



class GraphRAGExtractor(TransformComponent):
    """Extract triples from text and build knowledge-graph metadata.

    Uses an LLM and a simple prompt + output parsing to extract paths
    (i.e. triples) and entity/relation descriptions from text. Extracted
    entities and relations are attached to each node's metadata under
    ``KG_NODES_KEY`` and ``KG_RELATIONS_KEY``.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model. It must
            return a 2-tuple of (entities, relationships); entities are
            (name, type, description) triples and relationships are
            (subject, object, relation, description) 4-tuples.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params."""
        # Imported lazily to avoid a circular import at module load time.
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialized component name (kept for compatibility)."""
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes (synchronous entry point).

        Delegates to :meth:`acall`; ``nest_asyncio.apply()`` at module
        import allows ``asyncio.run`` to work even when an event loop is
        already running (e.g. inside a notebook).
        """
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a single node.

        Sends the node text to the LLM, parses the response into entities
        and relationships, converts them into ``EntityNode``/``Relation``
        objects, and stores them in the node's metadata.

        Returns:
            The same node, with ``KG_NODES_KEY``/``KG_RELATIONS_KEY``
            metadata populated (appending to any pre-existing entries).
        """
        assert hasattr(node, "text")  # only text-bearing nodes are supported

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            # Malformed/unparseable LLM output: fall back to "nothing
            # extracted" rather than failing the whole pipeline.
            entities = []
            entities_relationship = []

        # Pop first so the KG keys are not duplicated into node properties.
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])

        for entity, entity_type, description in entities:
            # Fresh copy per entity: mutating one shared dict would make
            # every EntityNode reference the same properties object, so all
            # entities would end up with the *last* entity's description.
            entity_metadata = node.metadata.copy()
            # Not used in the current implementation, but will be useful
            # during the retrieval stage in future work.
            entity_metadata["entity_description"] = description
            entity_node = EntityNode(
                name=entity, label=entity_type, properties=entity_metadata
            )
            existing_nodes.append(entity_node)

        for triple in entities_relationship:
            subj, obj, rel, description = triple
            # Fresh copies per triple so the relationship description does
            # not leak into the subject/object entity properties or into
            # nodes created for earlier/later triples.
            subj_node = EntityNode(name=subj, properties=node.metadata.copy())
            obj_node = EntityNode(name=obj, properties=node.metadata.copy())
            rel_metadata = node.metadata.copy()
            rel_metadata["relationship_description"] = description
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties=rel_metadata,
            )

            existing_nodes.extend([subj_node, obj_node])
            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async, fanned out across workers."""
        jobs = [self._aextract(node) for node in nodes]

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )



No comments:

Post a Comment