The GraphRAGExtractor class is designed to extract triples (subject-relation-object) from text and enrich them by adding descriptions for entities and relationships to their properties using an LLM.
This functionality is similar to that of the SimpleLLMPathExtractor, but includes additional enhancements to handle entity and relationship descriptions. For guidance on implementation, you may look at similar existing extractors.
Here's a breakdown of its functionality:
Key Components:
llm: The language model used for extraction.
extract_prompt: A prompt template used to guide the LLM in extracting information.
parse_fn: A function to parse the LLM's output into structured data.
max_paths_per_chunk: Limits the number of triples extracted per text chunk.
num_workers: For parallel processing of multiple text nodes.
Main Methods:
__call__: The entry point for processing a list of text nodes.
acall: An asynchronous version of __call__ for improved performance.
_aextract: The core method that processes each individual node.
Extraction Process:
For each input node (chunk of text):
It sends the text to the LLM along with the extraction prompt.
The LLM's response is parsed to extract entities, relationships, descriptions for entities and relations.
Entities are converted into EntityNode objects. The entity description is stored in metadata.
Relationships are converted into Relation objects. Relationship description is stored in metadata.
These are added to the node's metadata under KG_NODES_KEY and KG_RELATIONS_KEY.
NOTE: In the current implementation, we are using only relationship descriptions. In the next implementation, we will utilize entity descriptions during the retrieval stage.
import asyncio
import nest_asyncio
nest_asyncio.apply()
from typing import Any, List, Callable, Optional, Union, Dict
from IPython.display import Markdown, display
from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
EntityNode,
KG_NODES_KEY,
KG_RELATIONS_KEY,
Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field
class GraphRAGExtractor(TransformComponent):
    """Extract triples and their descriptions from text nodes.

    Uses an LLM and a simple prompt + output parsing to extract paths
    (i.e. triples) together with entity and relation descriptions from text.
    The extracted ``EntityNode`` and ``Relation`` objects are appended to each
    input node's metadata under ``KG_NODES_KEY`` and ``KG_RELATIONS_KEY``.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model. It must
            return a pair ``(entities, relationships)`` where ``entities``
            yields ``(name, type, description)`` tuples and ``relationships``
            yields ``(subject, object, relation, description)`` tuples.
            NOTE(review): the default ``default_parse_triplets_fn`` presumably
            returns plain triplets rather than this shape -- confirm, and pass
            a matching custom ``parse_fn`` when descriptions are needed.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params.

        Falls back to ``Settings.llm`` when no LLM is given and to
        ``DEFAULT_KG_TRIPLET_EXTRACT_PROMPT`` when no prompt is given. A plain
        string prompt is wrapped in a ``PromptTemplate``.
        """
        # Imported lazily, as in the original, to keep it out of module import.
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the component's registered name."""
        # Kept as "GraphExtractor" (not the Python class name) so anything
        # keyed on this string keeps working.
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes.

        Synchronous entry point: runs :meth:`acall` to completion with
        ``asyncio.run`` and returns its result.
        """
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a single node and attach them to its metadata.

        Sends the node content to the LLM, parses the response with
        ``parse_fn`` and appends the resulting ``EntityNode`` / ``Relation``
        objects to whatever is already stored under ``KG_NODES_KEY`` /
        ``KG_RELATIONS_KEY``. The node is mutated in place and returned.
        """
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            # Best effort: a ValueError here (e.g. unparsable output) yields
            # no extractions for this node instead of failing the batch.
            entities = []
            entities_relationship = []

        # Pop the KG keys first so the snapshot below does not nest earlier
        # extraction results inside every new node's properties.
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        base_metadata = node.metadata.copy()

        for entity, entity_type, description in entities:
            # Fresh dict per entity so descriptions never alias each other.
            # Not used in the current implementation, but useful in future
            # work (entity descriptions at retrieval time).
            entity_properties = {
                **base_metadata,
                "entity_description": description,
            }
            existing_nodes.append(
                EntityNode(
                    name=entity,
                    label=entity_type,
                    properties=entity_properties,
                )
            )

        for subj, obj, rel, description in entities_relationship:
            # Endpoint nodes carry only the chunk metadata. Previously a single
            # shared dict was mutated across iterations, so from the second
            # triple on they picked up the previous triple's
            # "relationship_description".
            subj_node = EntityNode(name=subj, properties=dict(base_metadata))
            obj_node = EntityNode(name=obj, properties=dict(base_metadata))
            relation_properties = {
                **base_metadata,
                "relationship_description": description,
            }
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties=relation_properties,
            )
            existing_nodes.extend([subj_node, obj_node])
            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async.

        Schedules one :meth:`_aextract` job per node and runs them through
        ``run_jobs`` with ``num_workers`` workers. Extra ``kwargs`` are
        accepted for interface compatibility and are not used.
        """
        jobs = [self._aextract(node) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
No comments:
Post a Comment