如何通过Neo4j扩展GraphRAG实现高效文档检索

2024年09月13日 由 alex 发表 63 0

在本文中,我将介绍可扩展图形RAG系统的示例实现架构,该系统利用Neo4j存储和管理从文档中提取的图形数据。我们将使用OpenAI的GPT-4o模型处理文档,提取实体和关系,并将它们存储在Neo4j图形中,以便更容易处理大型数据集并使用诸如中心性等图算法来回答查询。中心性度量有助于基于节点的连接标识图中最重要的节点,这对于快速准确地检索最相关的信息很重要。在这个示例中,我们强调基于中心性的检索而不是基于社区的检索,以提高查询响应的相关性。


先决条件

确保你具备以下条件:

  • Python 3.9+
  • Docker
  • 必要的库:openai,py2neo,python-dotenv


你可以使用以下命令安装这些库:


pip install openai py2neo python-dotenv


此外,我们将使用Docker来运行一个Neo4j实例来管理图形数据。


项目概述

该项目采用面向对象的方法进行结构化,为关键组件管理使用不同的类。该系统处理文档、提取实体和关系,并将它们存储在Neo4j中。使用中心性度量,我们优先考虑图形中最重要的实体,帮助提高查询响应的准确性和相关性。


项目结构

  • app.py:入口点,用于协调文档处理和查询工作流程。
  • GraphManager (在graph_manager.py中):管理Neo4j操作,如构建图形、重新计算中心性度量和管理更新。
  • QueryHandler (在query_handler.py中):处理用户查询,并利用GPT模型根据图形数据和中心性度量提供响应。
  • DocumentProcessor (在document_processor.py中):将文档拆分成片段,提取实体和关系,并对其进行总结。
  • GraphDatabase (在graph_database.py中):管理与Neo4j数据库的连接。
  • logger.py (在logger.py中):提供日志记录工具,以追踪应用程序的进展。


使用Docker设置Neo4j

要在本地设置Neo4j,请运行以下命令构建和启动Docker容器:


sh build.sh
sh start.sh


这将在本地运行一个Neo4j实例,可以通过http://localhost:7474和bolt://localhost:7687访问。


从Python连接到Neo4j

我们将使用py2neo库连接到Neo4j数据库。graph_database.py中的GraphDatabaseConnection类处理此连接:


from py2neo import Graph
from logger import Logger
import os

DB_URL = os.getenv("DB_URL")
DB_USERNAME = os.getenv("DB_USERNAME")
DB_PASSWORD = os.getenv("DB_PASSWORD")

class GraphDatabaseConnection:
    logger = Logger("GraphDatabaseConnection").get_logger()
    def __init__(self, db_url=DB_URL, username=DB_USERNAME, password=DB_PASSWORD):
        self.db_url = db_url
        self.username = username
        self.password = password
        self.graph = self.connect()
    def connect(self):
        try:
            graph = Graph(self.db_url, auth=(self.username, self.password))
            self.logger.info("Connected to the database")
            return graph
        except Exception as e:
            self.logger.error(f"Error connecting to the database: {e}")
            return None
    def clear_database(self):
        if self.graph:
            self.graph.delete_all()
            self.logger.info("Deleted all data from the database")


文件处理与DocumentProcessor

DocumentProcessor类负责通过将文件拆分成块、提取关键实体和关系,并使用OpenAI的GPT模型对它们进行总结。


示例:文件处理


from logger import Logger

class DocumentProcessor:
    logger = Logger("DocumentProcessor").get_logger()
    def __init__(self, client, model):
        self.client = client
        self.model = model
    def split_documents(self, documents, chunk_size=600, overlap_size=100):
        chunks = []
        for document in documents:
            for i in range(0, len(document), chunk_size - overlap_size):
                chunk = document[i:i + chunk_size]
                chunks.append(chunk)
        self.logger.debug("Documents split into %d chunks", len(chunks))
        return chunks
    def extract_elements(self, chunks):
        elements = []
        for index, chunk in enumerate(chunks):
            self.logger.debug(
                f"Extracting elements and relationship strength from chunk {index + 1}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system",
                        "content": "Extract entities, relationships, and their strength from the following text. Use common terms such as 'related to', 'depends on', 'influences', etc., for relationships, and estimate a strength between 0.0 (very weak) and 1.0 (very strong). Format: Parsed relationship: Entity1 -> Relationship -> Entity2 [strength: X.X]. Do not include any other text in your response. Use this exact format: Parsed relationship: Entity1 -> Relationship -> Entity2 [strength: X.X]."},
                    {"role": "user", "content": chunk}
                ]
            )
            entities_and_relations = response.choices[0].message.content
            elements.append(entities_and_relations)
        self.logger.debug("Elements extracted")
        return elements
    def summarize_elements(self, elements):
        summaries = []
        for index, element in enumerate(elements):
            self.logger.debug(f"Summarizing element {index + 1}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "Summarize the following entities and relationships in a structured format. Use common terms such as 'related to', 'depends on', 'influences', etc., for relationships. Use '->' to represent relationships after the 'Relationships:' word."},
                    {"role": "user", "content": element}
                ]
            )
            summary = response.choices[0].message.content
            summaries.append(summary)
        self.logger.debug("Summaries created")
        return summaries


通过GraphManager进行图形管理

一旦提取出实体和关系,它们就会使用GraphManager类存储在图形中。该类负责构建和重新构建图形,计算中心性度量,并在添加新数据时管理更新。


示例:构建图形

GraphManager类中的build_graph_in_neo4j方法负责根据文档摘要创建节点和关系:


class GraphManager:
    def __init__(self, graph_database):
        self.graph = graph_database.graph
    def build_graph(self, summaries):
        if self.graph is None:
            self.logger.error("Graph database connection is not available.")
            return
        entities = {}
        for summary in summaries:
            lines = summary.split("\n")
            entities_section = False
            relationships_section = False
            for line in lines:
                if line.startswith("### Entities:") or line.startswith("**Entities:**") or line.startswith("Entities:"):
                    entities_section = True
                    relationships_section = False
                    continue
                elif line.startswith("### Relationships:") or line.startswith("**Relationships:**") or line.startswith("Relationships:"):
                    entities_section = False
                    relationships_section = True
                    continue
                if entities_section and line.strip():
                    if line[0].isdigit() and '.' in line:
                        entity_name = line.split(".", 1)[1].strip()
                    else:
                        entity_name = line.strip()
                    entity_name = self.normalize_entity_name(
                        entity_name.replace("**", ""))
                    node = Node("Entity", name=entity_name)
                    self.logger.debug(f"Creating node: {entity_name}")
                    self.graph.merge(node, "Entity", "name")
                    entities[entity_name] = node
                elif relationships_section and line.strip():
                    parts = line.split("->")
                    if len(parts) >= 2:
                        source = self.normalize_entity_name(parts[0].strip())
                        target = self.normalize_entity_name(parts[-1].strip())
                        relationship_part = parts[1].strip()
                        relation_name = self.sanitize_relationship_name(
                            relationship_part.split("[")[0].strip())
                        strength = re.search(
                            r"\[strength:\s*(\d\.\d)\]", relationship_part)
                        weight = float(strength.group(1)) if strength else 1.0
                        self.logger.debug(
                            f"Parsed relationship: {source} -> {relation_name} -> {target} [weight: {weight}]")
                        if source in entities and target in entities:
                            if relation_name:
                                self.logger.debug(
                                    f"Creating relationship: {source} -> {relation_name} -> {target} with weight {weight}")
                                relation = Relationship(
                                    entities[source], relation_name, entities[target])
                                relation["weight"] = weight
                                self.graph.merge(relation)
                            else:
                                self.logger.debug(
                                    f"Skipping relationship: {source} -> {relation_name} -> {target} (relation name is empty)")
                        else:
                            self.logger.debug(
                                f"Skipping relationship: {source} -> {relation_name} -> {target} (one or both entities not found)")
# NOTE: More methods in the class, see the full code for details


增强查询响应的中心性度量

GraphManager还计算中心性度量,如度中心性、介数中心性和接近中心性。这些度量有助于优先考虑图中的关键实体,提高查询响应的相关性。


示例:计算中心性度量

calculate_centrality_measures方法计算图中每个节点的中心性度量:


def calculate_centrality_measures(self, graph_name="entityGraph"):
    self.reproject_graph(graph_name)
    check_query = f"CALL gds.graph.exists($graph_name) YIELD exists"
    exists_result = self.graph.run(
        check_query, graph_name=graph_name).evaluate()
    if not exists_result:
        raise Exception(f"Graph projection '{graph_name}' does not exist.")
    degree_centrality_query = f"""
    CALL gds.degree.stream($graph_name)
    YIELD nodeId, score
    RETURN gds.util.asNode(nodeId).name AS entityName, score
    ORDER BY score DESC
    LIMIT 10
    """
    degree_centrality_result = self.graph.run(
        degree_centrality_query, graph_name=graph_name).data()
    betweenness_centrality_query = f"""
    CALL gds.betweenness.stream($graph_name)
    YIELD nodeId, score
    RETURN gds.util.asNode(nodeId).name AS entityName, score
    ORDER BY score DESC
    LIMIT 10
    """
    betweenness_centrality_result = self.graph.run(
        betweenness_centrality_query, graph_name=graph_name).data()
    closeness_centrality_query = f"""
    CALL gds.closeness.stream($graph_name)
    YIELD nodeId, score
    RETURN gds.util.asNode(nodeId).name AS entityName, score
    ORDER BY score DESC
    LIMIT 10
    """
    closeness_centrality_result = self.graph.run(
        closeness_centrality_query, graph_name=graph_name).data()
    centrality_data = {
        "degree": degree_centrality_result,
        "betweenness": betweenness_centrality_result,
        "closeness": closeness_centrality_result
    }
    return centrality_data


使用QueryHandler处理查询

QueryHandler类利用中心性度量的结果和OpenAI的GPT模型来生成更相关和准确的响应,以回答用户的查询。


示例:处理查询


from graph_manager import GraphManager
from openai import OpenAI
from logger import Logger

class QueryHandler:
    logger = Logger("QueryHandler").get_logger()
    def __init__(self, graph_manager: GraphManager, client: OpenAI, model: str):
        self.graph_manager = graph_manager
        self.client = client
        self.model = model
    def ask_question(self, query):
        centrality_data = self.graph_manager.calculate_centrality_measures()
        centrality_summary = self.graph_manager.summarize_centrality_measures(
            centrality_data)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "Use the centrality measures to answer the following query."},
                {"role": "user", "content": f"Query: {query} Centrality Summary: {centrality_summary}"}
            ]
        )
        self.logger.debug("Query answered: %s",
                          response.choices[0].message.content)
        final_answer = response.choices[0].message.content
        return final_answer


通过关注最中心的实体,系统能生成更好、更具上下文感知能力的答案。


使用新文档重新索引

当添加新文档时,图形会重新索引以更新实体、关系和中心性度量。根目录app.py中的reindex_with_new_documents函数处理这个过程。


示例:重新索引


def reindex_with_new_documents(new_documents, graph_manager: GraphManager):
    chunks = document_processor.split_documents(new_documents)
    elements_file = 'data/new_elements_data.pkl'
    summaries_file = 'data/new_summaries_data.pkl'
    elements = load_or_run(
        elements_file, document_processor.extract_elements, chunks)
    summaries = load_or_run(
        summaries_file, document_processor.summarize_elements, elements)
    graph_manager.build_graph(summaries)
    graph_manager.reproject_graph()


这样确保了图形与新数据保持更新,并且中心性度量被重新计算。


运行应用程序

设置环境后,运行应用程序:


python app.py


这将会:

  1. 对初始文档进行索引。
  2. 处理用户查询以提取主要主题。
  3. 使用新文档对图形进行重新索引。
  4. 基于更新后的图形回答另一个查询。


结论

通过使用Neo4j并采用基于类的方法以清晰地分离关注点,我们建立了一个可扩展和高效的GraphRAG流程。该系统可以处理更大的数据集,利用图算法来增强查询响应,并在添加新数据时不断更新图形。


此设计使你可以进一步扩展系统,包括其他算法或更大的数据集,并将其根据特定的业务需求进行定制。

文章来源:https://medium.com/thedeephub/how-to-scale-graphrag-with-neo4j-for-efficient-document-querying-f8be1ae4feb3
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消