GraphRAG-01-Index-索引引擎

模块概览

职责

Index 模块是 GraphRAG 的核心索引构建引擎,负责将非结构化文本转换为结构化知识图谱。主要职责包括:

  • 文档加载与分块:从多种数据源(CSV、Text、JSON)加载文档,按配置进行文本分块
  • 实体关系抽取:使用 LLM 从文本中抽取实体(Entity)和关系(Relationship)
  • 图构建与合并:将抽取的实体关系合并为统一的知识图谱
  • 社区发现:使用 Leiden 算法进行图聚类,发现层级化社区结构
  • 社区报告生成:为每个社区生成自然语言摘要报告
  • 嵌入计算:计算实体、关系、文本单元的向量嵌入
  • 增量更新:支持在现有索引基础上增量添加新文档

输入与输出

输入

  • 原始文档文件(TXT、CSV、JSON)
  • 配置对象(GraphRagConfig)
  • 可选:现有索引(增量更新模式)

输出(存储在 output/ 目录下的 Parquet 文件):

  • documents.parquet:文档元数据
  • text_units.parquet:文本块
  • entities.parquet:实体列表
  • relationships.parquet:关系列表
  • communities.parquet:社区列表
  • community_reports.parquet:社区报告
  • covariates.parquet:协变量(可选,如 claims)

上下游依赖

上游依赖

  • Config 模块:加载配置
  • Storage 模块:读写文件
  • Language Model 模块:LLM 调用
  • Cache 模块:LLM 结果缓存

下游消费者

  • Query 模块:使用索引输出进行查询
  • 外部应用:直接读取 Parquet 文件进行分析

生命周期

  1. 初始化阶段:加载配置、创建存储、初始化缓存
  2. 执行阶段:顺序执行工作流管道
  3. 清理阶段:关闭缓存连接、记录统计信息

模块架构图

flowchart TB
    subgraph API["索引 API 层"]
        BuildIndex[build_index]
    end

    subgraph Factory["管道工厂"]
        PipelineFactory[Pipeline Factory]
        WorkflowRegistry[Workflow Registry]
    end

    subgraph Pipeline["管道执行器"]
        Runner[Pipeline Runner]
        Context[Pipeline Context]
    end

    subgraph Workflows["工作流层"]
        direction LR
        WF1[load_input_documents]
        WF2[create_base_text_units]
        WF3[extract_graph]
        WF4[finalize_graph]
        WF5[create_communities]
        WF6[create_community_reports]
        WF7[generate_text_embeddings]
    end

    subgraph Operations["操作层"]
        direction TB
        OP1[chunk_text]
        OP2[extract_graph]
        OP3[summarize_descriptions]
        OP4[cluster_graph]
        OP5[embed_text]
        OP6[create_graph]
        OP7[finalize_entities]
    end

    subgraph Utils["工具层"]
        DeriveFromRows[derive_from_rows<br/>异步并发执行]
        Hashing[gen_sha512_hash<br/>ID 生成]
        GraphUtils[图工具<br/>create_graph/layout_graph]
    end

    subgraph External["外部依赖"]
        LLM[LLM Service]
        Storage[Storage Layer]
        Cache[Cache Layer]
        VectorStore[Vector Store]
    end

    BuildIndex --> PipelineFactory
    PipelineFactory --> Runner
    Runner --> Context
    Runner --> Workflows
    
    WF1 --> OP1
    WF2 --> OP1
    WF3 --> OP2
    WF3 --> OP3
    WF4 --> OP7
    WF5 --> OP4
    WF5 --> OP6
    WF6 --> LLM
    WF7 --> OP5
    
    Operations --> Utils
    Operations --> LLM
    Operations --> Cache
    Operations --> Storage
    OP5 --> VectorStore
    
    Context --> Storage
    Context --> Cache

架构图说明

层次职责

API 层

  • 提供 build_index 函数作为统一入口
  • 接受配置、方法类型(Standard/Fast/Update)、回调函数
  • 返回 PipelineRunResult 列表

管道工厂

  • 根据索引方法(Standard/Fast/Update)选择工作流序列
  • 注册和管理所有工作流函数
  • 创建 Pipeline 对象

管道执行器

  • 按顺序执行工作流
  • 管理执行上下文(存储、缓存、回调)
  • 捕获并聚合错误

工作流层

  • 每个工作流对应一个高层次业务逻辑
  • 读取输入数据,调用操作层,写入输出数据
  • 独立错误边界

操作层

  • 原子性数据处理逻辑
  • 可被多个工作流复用
  • 封装复杂算法(如 Leiden 聚类、LLM 调用)

核心工作流详解

1. load_input_documents 工作流

功能说明

加载原始文档并进行初步处理。

请求结构

输入(从配置和存储中读取):

字段 类型 说明
input.type str 输入类型:fileblob
input.file_type str 文件类型:textcsvjson
input.base_dir str 输入目录路径
input.file_pattern str 文件匹配模式(如 *.txt
input.encoding str 文件编码(默认 utf-8

响应结构

输出(写入 documents.parquet):

字段 类型 必填 说明
id str 文档唯一 ID(SHA-512 哈希)
title str 文档标题(文件名或指定列)
text str 文档正文
metadata dict 文档元数据(JSON)

入口函数与核心代码

async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    # 1. 加载输入配置
    input_config = config.input
    
    # 2. 创建输入数据加载器
    output = await load_input_documents(input_config, context.input_storage)
    
    # 3. 写入文档表到存储
    await write_table_to_storage(output, "documents", context.output_storage)
    
    return WorkflowFunctionOutput(result=output)


async def load_input_documents(
    config: InputConfig, storage: PipelineStorage
) -> pd.DataFrame:
    # 根据文件类型创建输入加载器
    documents_df = await create_input(config, storage)
    
    # 生成文档 ID(基于内容哈希)
    documents_df["id"] = documents_df.apply(
        lambda row: gen_sha512_hash(row, ["title", "text"]), axis=1
    )
    
    return documents_df

时序图

sequenceDiagram
    autonumber
    participant WF as Workflow: load_input_documents
    participant Factory as Input Factory
    participant Loader as File/Blob Loader
    participant Storage as Storage Layer
    
    WF->>Factory: create_input(config, storage)
    Factory->>Factory: 判断输入类型 (file/blob/csv/json)
    
    alt 文件类型: text
        Factory->>Loader: load_text_files(base_dir, pattern)
        loop 每个文件
            Loader->>Storage: 读取文件内容
            Storage-->>Loader: 文件文本
            Loader->>Loader: 解析为 Document
        end
        Loader-->>Factory: List[Document]
    else 文件类型: csv
        Factory->>Loader: load_csv(base_dir, pattern)
        Loader->>Storage: 读取 CSV 文件
        Storage-->>Loader: CSV 数据
        Loader->>Loader: 转换为 DataFrame
        Loader-->>Factory: DataFrame
    end
    
    Factory-->>WF: documents DataFrame
    
    WF->>WF: 生成文档 ID (SHA-512 哈希)
    WF->>Storage: write_table_to_storage(documents, "documents")
    Storage-->>WF: 写入完成

边界与异常

  • 重复文档:基于内容哈希,重复文档会被去重
  • 文件读取失败:记录错误,跳过该文件,继续处理其他文件
  • 编码错误:尝试多种编码(utf-8、gbk、latin-1)

2. create_base_text_units 工作流

功能说明

将文档分割为固定大小的文本块(text units),作为后续处理的基本单元。

请求结构

输入(从配置中读取):

字段 类型 必填 默认 说明
chunks.size int 300 每个文本块的 token 数
chunks.overlap int 100 相邻块之间的重叠 token 数
chunks.encoding_model str cl100k_base 编码模型(tiktoken)
chunks.strategy str tokens 分块策略(tokens/sentences)
chunks.group_by_columns list[str] [] 分组列(如按文档分组)

响应结构

输出(写入 text_units.parquet):

字段 类型 必填 说明
id str 文本块唯一 ID(SHA-512 哈希)
text str 文本块内容
n_tokens int Token 数量
document_ids list[str] 来源文档 ID 列表

入口函数与核心代码

async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    # 1. 加载文档数据
    documents = await load_table_from_storage("documents", context.output_storage)
    
    # 2. 执行文本分块
    chunks = config.chunks
    output = create_base_text_units(
        documents,
        context.callbacks,
        group_by_columns=chunks.group_by_columns,
        size=chunks.size,
        overlap=chunks.overlap,
        encoding_model=chunks.encoding_model,
        strategy=chunks.strategy,
    )
    
    # 3. 写入文本块表
    await write_table_to_storage(output, "text_units", context.output_storage)
    
    return WorkflowFunctionOutput(result=output)


def create_base_text_units(
    documents: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    group_by_columns: list[str],
    size: int,
    overlap: int,
    encoding_model: str,
    strategy: ChunkStrategyType,
) -> pd.DataFrame:
    # 1. 按指定列分组(如按文档分组)
    aggregated = documents.groupby(group_by_columns).agg({"text": list})
    
    # 2. 对每组文本进行分块
    def chunker(row):
        chunked = chunk_text(
            text=row["text"],
            size=size,
            overlap=overlap,
            encoding_model=encoding_model,
            strategy=strategy,
            callbacks=callbacks,
        )
        row["chunks"] = chunked
        return row
    
    aggregated = aggregated.apply(chunker, axis=1)
    
    # 3. 展开分块结果
    aggregated = aggregated.explode("chunks")
    
    # 4. 生成文本块 ID
    aggregated["id"] = aggregated.apply(
        lambda row: gen_sha512_hash(row, ["chunk"]), axis=1
    )
    
    return aggregated

时序图

sequenceDiagram
    autonumber
    participant WF as Workflow: create_base_text_units
    participant Storage as Storage Layer
    participant Chunker as chunk_text Operation
    participant Tokenizer as Tokenizer
    
    WF->>Storage: load_table_from_storage("documents")
    Storage-->>WF: documents DataFrame
    
    WF->>WF:  group_by_columns 分组
    
    loop 每个分组
        WF->>Chunker: chunk_text(text, size, overlap, ...)
        Chunker->>Tokenizer: 编码文本为 tokens
        Tokenizer-->>Chunker: token 列表
        Chunker->>Chunker:  size  overlap 切分
        Chunker-->>WF: 分块列表
    end
    
    WF->>WF: 展开分块结果为行
    WF->>WF: 生成分块 ID (SHA-512)
    
    WF->>Storage: write_table_to_storage(text_units, "text_units")
    Storage-->>WF: 写入完成

边界与异常

  • Token 超限:如果单个文档超过最大 token 限制,会被分割为多个块
  • 重叠处理:相邻块之间重叠部分确保语义连续性
  • 空文本:空白或纯空格文本会被过滤

3. extract_graph 工作流

功能说明

使用 LLM 从文本块中抽取实体(Entity)和关系(Relationship)。

请求结构

输入(从配置中读取):

字段 类型 必填 默认 说明
extract_graph.model_id str default 使用的 LLM 模型 ID
extract_graph.prompt str 默认提示词 实体抽取提示词模板
extract_graph.max_gleanings int 1 迭代抽取次数
extract_graph.entity_types list[str] ["organization", "person", "geo", "event"] 抽取的实体类型
extract_graph.concurrent_requests int 25 并发 LLM 请求数

响应结构

输出 1: entities.parquet

字段 类型 必填 说明
title str 实体名称(唯一标识)
type str 实体类型
description list[str] 实体描述列表(多个文本块中的描述)
text_unit_ids list[str] 提及该实体的文本块 ID 列表
frequency int 实体出现频率

输出 2: relationships.parquet

字段 类型 必填 说明
source str 源实体名称
target str 目标实体名称
description list[str] 关系描述列表
weight float 关系权重(累加)
text_unit_ids list[str] 提及该关系的文本块 ID 列表

入口函数与核心代码

async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    # 1. 加载文本块数据
    text_units = await load_table_from_storage("text_units", context.output_storage)
    
    # 2. 获取 LLM 配置和策略
    extract_graph_llm_settings = config.get_language_model_config(
        config.extract_graph.model_id
    )
    extraction_strategy = config.extract_graph.resolved_strategy(
        config.root_dir, extract_graph_llm_settings
    )
    
    # 3. 执行实体关系抽取
    entities, relationships, raw_entities, raw_relationships = await extract_graph(
        text_units=text_units,
        callbacks=context.callbacks,
        cache=context.cache,
        extraction_strategy=extraction_strategy,
        extraction_num_threads=extract_graph_llm_settings.concurrent_requests,
        entity_types=config.extract_graph.entity_types,
    )
    
    # 4. 写入抽取结果
    await write_table_to_storage(entities, "entities", context.output_storage)
    await write_table_to_storage(relationships, "relationships", context.output_storage)
    
    return WorkflowFunctionOutput(result={"entities": entities, "relationships": relationships})


async def extract_graph(
    text_units: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache,
    extraction_strategy: dict[str, Any],
    extraction_num_threads: int,
    entity_types: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    # 1. 加载抽取策略(graph_intelligence)
    strategy_exec = _load_strategy(extraction_strategy["type"])
    
    # 2. 并发执行抽取任务
    async def run_strategy(row):
        text = row["text"]
        id = row["id"]
        
        # 调用策略函数(内部调用 LLM)
        result = await strategy_exec(
            [Document(text=text, id=id)],
            entity_types,
            cache,
            extraction_strategy,
        )
        
        return [result.entities, result.relationships, result.graph]
    
    # 使用 derive_from_rows 并发执行
    results = await derive_from_rows(
        text_units,
        run_strategy,
        callbacks,
        num_threads=extraction_num_threads,
    )
    
    # 3. 合并抽取结果
    entities = _merge_entities(results)
    relationships = _merge_relationships(results)
    
    return (entities, relationships, raw_entities, raw_relationships)


def _merge_entities(entity_dfs) -> pd.DataFrame:
    # 按实体名称和类型分组
    all_entities = pd.concat(entity_dfs, ignore_index=True)
    return all_entities.groupby(["title", "type"], sort=False).agg(
        description=("description", list),
        text_unit_ids=("source_id", list),
        frequency=("source_id", "count"),
    ).reset_index()


def _merge_relationships(relationship_dfs) -> pd.DataFrame:
    # 按源和目标实体分组
    all_relationships = pd.concat(relationship_dfs)
    return all_relationships.groupby(["source", "target"], sort=False).agg(
        description=("description", list),
        text_unit_ids=("source_id", list),
        weight=("weight", "sum"),
    ).reset_index()

调用链分析

上层调用链

build_index (API)
  └─> run_pipeline (Pipeline Runner)
      └─> run_workflow (extract_graph Workflow)
          └─> extract_graph (Operation)
              └─> derive_from_rows (并发工具)
                  └─> run_strategy (策略函数)
                      └─> run_graph_intelligence (LLM 策略)
                          └─> GraphExtractor.extract (图抽取器)
                              └─> ChatModel.call (LLM 调用)

LLM 调用详情

# graphrag/index/operations/extract_graph/graph_intelligence_strategy.py
async def run_graph_intelligence(
    docs: list[Document],
    entity_types: EntityTypes,
    cache: PipelineCache,
    args: StrategyConfig,
) -> EntityExtractionResult:
    # 1. 创建 LLM 模型实例
    model = ModelManager().get_or_create_chat_model(
        name="entity_extraction",
        model_type=args["model_type"],
        config=args["model_config"],
    )
    
    # 2. 构建提示词
    prompt = build_entity_extraction_prompt(
        docs=docs,
        entity_types=entity_types,
        prompt_template=args.get("prompt"),
    )
    
    # 3. 调用 LLM
    response = await model.call(
        messages=[{"role": "user", "content": prompt}],
        callbacks=callbacks,
        cache=cache,
    )
    
    # 4. 解析响应
    entities, relationships = parse_llm_response(response)
    
    return EntityExtractionResult(entities=entities, relationships=relationships)

时序图

sequenceDiagram
    autonumber
    participant WF as Workflow: extract_graph
    participant Storage as Storage
    participant DFR as derive_from_rows
    participant Strategy as Graph Intelligence Strategy
    participant Cache as Cache
    participant LLM as LLM Service
    
    WF->>Storage: load_table_from_storage("text_units")
    Storage-->>WF: text_units DataFrame
    
    WF->>DFR: derive_from_rows(text_units, run_strategy, threads=25)
    
    loop 每个 text_unit(并发)
        DFR->>Strategy: run_graph_intelligence(document, entity_types, cache, config)
        
        Strategy->>Cache: 检查缓存 (hash(text + prompt))
        
        alt 缓存命中
            Cache-->>Strategy: 缓存的抽取结果
        else 缓存未命中
            Strategy->>Strategy: 构建实体抽取提示词
            Strategy->>LLM: ChatModel.call(prompt)
            LLM-->>Strategy: LLM 响应 (JSON)
            Strategy->>Strategy: 解析 JSON  Entity/Relationship 列表
            Strategy->>Cache: 存储结果到缓存
        end
        
        Strategy-->>DFR: EntityExtractionResult
    end
    
    DFR-->>WF: List[entities, relationships]
    
    WF->>WF: _merge_entities (按名称和类型分组)
    WF->>WF: _merge_relationships (按源和目标分组)
    
    WF->>Storage: write_table_to_storage(entities, "entities")
    WF->>Storage: write_table_to_storage(relationships, "relationships")

边界与异常

  • LLM 调用失败

    • 重试策略:指数退避,最多 10 次
    • 超时:默认 180 秒
    • 失败记录:记录到日志,继续处理其他文本块
  • JSON 解析失败

    • 使用 json-repair 库修复损坏的 JSON
    • 修复失败则跳过该文本块
  • 实体冲突

    • 同名不同类型:合并为多类型实体
    • 描述冲突:保留所有描述作为列表

4. create_communities 工作流

功能说明

使用 Leiden 算法对知识图谱进行层级聚类,发现社区结构。

请求结构

输入(从配置中读取):

字段 类型 必填 默认 说明
cluster_graph.max_cluster_size int 10 单个社区最大实体数
cluster_graph.use_lcc bool True 是否仅使用最大连通分量
cluster_graph.seed int 随机 随机种子(用于可复现)

响应结构

输出(写入 communities.parquet):

字段 类型 必填 说明
id str 社区唯一 ID(UUID)
human_readable_id int 社区编号
community int 社区 ID(层级内唯一)
level int 层级级别(0=最底层)
title str 社区标题(如 “Community 0”)
parent int 父社区 ID(-1 表示根)
children list[int] 子社区 ID 列表
entity_ids list[str] 社区包含的实体 ID 列表
relationship_ids list[str] 社区内部的关系 ID 列表
text_unit_ids list[str] 社区涉及的文本块 ID 列表
size int 社区大小(实体数)
period str 创建日期(ISO 8601)

入口函数与核心代码

async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    # 1. 加载实体和关系数据
    entities = await load_table_from_storage("entities", context.output_storage)
    relationships = await load_table_from_storage("relationships", context.output_storage)
    
    # 2. 执行社区发现
    max_cluster_size = config.cluster_graph.max_cluster_size
    use_lcc = config.cluster_graph.use_lcc
    seed = config.cluster_graph.seed
    
    output = create_communities(
        entities,
        relationships,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=seed,
    )
    
    # 3. 写入社区表
    await write_table_to_storage(output, "communities", context.output_storage)
    
    return WorkflowFunctionOutput(result=output)


def create_communities(
    entities: pd.DataFrame,
    relationships: pd.DataFrame,
    max_cluster_size: int,
    use_lcc: bool,
    seed: int | None = None,
) -> pd.DataFrame:
    # 1. 构建图对象
    graph = create_graph(relationships, edge_attr=["weight"])
    
    # 2. 执行 Leiden 聚类
    clusters = cluster_graph(
        graph,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=seed,
    )
    
    # 3. 转换为 DataFrame
    communities = pd.DataFrame(
        clusters, columns=["level", "community", "parent", "title"]
    ).explode("title")
    
    # 4. 聚合实体 ID
    entity_ids = communities.merge(entities, on="title", how="inner")
    entity_ids = entity_ids.groupby("community").agg(entity_ids=("id", list))
    
    # 5. 聚合关系 ID(仅包含社区内部关系)
    for level in range(communities["level"].max() + 1):
        communities_at_level = communities[communities["level"] == level]
        sources = relationships.merge(communities_at_level, left_on="source", right_on="title")
        targets = sources.merge(communities_at_level, left_on="target", right_on="title")
        matched = targets[targets["community_x"] == targets["community_y"]]
        # 聚合关系 ID 和文本单元 ID
    
    # 6. 合并所有信息
    final_communities = communities.merge(entity_ids, on="community")
    final_communities["id"] = [str(uuid4()) for _ in range(len(final_communities))]
    final_communities["size"] = final_communities["entity_ids"].apply(len)
    
    return final_communities

cluster_graph 操作核心代码

# graphrag/index/operations/cluster_graph.py
def cluster_graph(
    graph: nx.Graph,
    max_cluster_size: int,
    use_lcc: bool,
    seed: int | None = None,
) -> Communities:
    # 1. 可选:提取最大连通分量
    if use_lcc:
        graph = stable_largest_connected_component(graph)
    
    # 2. 执行层级 Leiden 聚类
    community_mapping = _compute_leiden_communities(
        graph=graph,
        max_cluster_size=max_cluster_size,
        seed=seed,
    )
    
    # 3. 构建层级结构
    results = []
    for level in sorted(community_mapping.keys()):
        clusters_at_level = community_mapping[level]
        for cluster_id, members in clusters_at_level.items():
            parent_id = _find_parent_community(cluster_id, level, community_mapping)
            results.append({
                "level": level,
                "community": cluster_id,
                "parent": parent_id,
                "title": list(members),
            })
    
    return results


def _compute_leiden_communities(
    graph: nx.Graph,
    max_cluster_size: int,
    seed: int | None = None,
) -> dict[int, dict[int, set]]:
    # 使用 graspologic 库的 hierarchical_leiden
    from graspologic.partition import hierarchical_leiden
    
    community_mapping = hierarchical_leiden(
        graph,
        max_cluster_size=max_cluster_size,
        random_seed=seed,
    )
    
    return community_mapping

时序图

sequenceDiagram
    autonumber
    participant WF as Workflow: create_communities
    participant Storage as Storage
    participant CreateGraph as create_graph Operation
    participant Cluster as cluster_graph Operation
    participant Leiden as Leiden Algorithm (graspologic)
    
    WF->>Storage: load_table_from_storage("entities")
    Storage-->>WF: entities DataFrame
    
    WF->>Storage: load_table_from_storage("relationships")
    Storage-->>WF: relationships DataFrame
    
    WF->>CreateGraph: create_graph(relationships, edge_attr=["weight"])
    CreateGraph->>CreateGraph: 构建 NetworkX Graph
    CreateGraph-->>WF: NetworkX Graph 对象
    
    WF->>Cluster: cluster_graph(graph, max_cluster_size, use_lcc, seed)
    
    alt use_lcc = True
        Cluster->>Cluster: stable_largest_connected_component(graph)
    end
    
    Cluster->>Leiden: hierarchical_leiden(graph, max_cluster_size, seed)
    Leiden->>Leiden: 执行多层 Leiden 聚类
    Leiden-->>Cluster: community_mapping (层级社区映射)
    
    Cluster->>Cluster: 构建层级结构(level, community, parent, title
    Cluster-->>WF: Communities 列表
    
    WF->>WF: 转换为 DataFrame
    WF->>WF: 聚合实体 ID
    WF->>WF: 聚合关系 ID(按层级)
    WF->>WF: 计算社区大小和父子关系
    
    WF->>Storage: write_table_to_storage(communities, "communities")

边界与异常

  • 图不连通:如果 use_lcc=True,仅处理最大连通分量,其他节点被丢弃
  • 社区过大:Leiden 算法会自动细分,直到满足 max_cluster_size 约束
  • 随机性:如果未指定 seed,每次运行结果可能不同

5. create_community_reports 工作流

功能说明

为每个社区生成自然语言摘要报告,描述社区的主题、关键实体和重要关系。

请求结构

输入(从配置中读取):

字段 类型 必填 默认 说明
community_reports.model_id str default 使用的 LLM 模型 ID
community_reports.prompt str 默认提示词 社区报告生成提示词
community_reports.max_length int 2000 报告最大长度(token)
community_reports.concurrent_requests int 25 并发 LLM 请求数

响应结构

输出(写入 community_reports.parquet):

字段 类型 必填 说明
id str 报告唯一 ID
human_readable_id int 社区编号
community int 社区 ID
level int 层级级别
title str 报告标题
summary str 社区摘要
full_content str 完整报告内容
full_content_json dict 报告结构化数据(JSON)
rank float 报告重要性排名
rank_explanation str 排名解释
findings list[dict] 发现列表(每个包含摘要和解释)

入口函数与核心代码

async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    # 1. 加载必要数据
    communities = await load_table_from_storage("communities", context.output_storage)
    entities = await load_table_from_storage("entities", context.output_storage)
    relationships = await load_table_from_storage("relationships", context.output_storage)
    text_units = await load_table_from_storage("text_units", context.output_storage)
    
    # 2. 准备上下文数据
    community_contexts = prepare_community_contexts(
        communities, entities, relationships, text_units
    )
    
    # 3. 并发生成报告
    reports = await generate_community_reports(
        community_contexts=community_contexts,
        callbacks=context.callbacks,
        cache=context.cache,
        model_config=config.community_reports,
        num_threads=config.community_reports.concurrent_requests,
    )
    
    # 4. 写入报告表
    await write_table_to_storage(reports, "community_reports", context.output_storage)
    
    return WorkflowFunctionOutput(result=reports)


async def generate_community_reports(
    community_contexts: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache,
    model_config: CommunityReportsConfig,
    num_threads: int,
) -> pd.DataFrame:
    # 并发生成报告
    async def generate_report(row):
        community_id = row["community"]
        context_text = row["context"]
        
        # 调用 LLM 生成报告
        report = await _generate_single_report(
            community_id=community_id,
            context=context_text,
            cache=cache,
            model_config=model_config,
        )
        
        return report
    
    reports = await derive_from_rows(
        community_contexts,
        generate_report,
        callbacks,
        num_threads=num_threads,
    )
    
    return pd.DataFrame(reports)


async def _generate_single_report(
    community_id: int,
    context: str,
    cache: PipelineCache,
    model_config: CommunityReportsConfig,
) -> dict:
    # 1. 构建提示词
    prompt = build_community_report_prompt(
        community_id=community_id,
        context=context,
        prompt_template=model_config.prompt,
        max_length=model_config.max_length,
    )
    
    # 2. 调用 LLM
    model = ModelManager().get_or_create_chat_model(
        name="community_reports",
        model_type=model_config.model_type,
        config=model_config.model_config,
    )
    
    response = await model.call(
        messages=[{"role": "user", "content": prompt}],
        cache=cache,
    )
    
    # 3. 解析响应
    report_data = parse_community_report_response(response)
    
    return {
        "community": community_id,
        "title": report_data["title"],
        "summary": report_data["summary"],
        "findings": report_data["findings"],
        "rank": report_data["rating"],
        "rank_explanation": report_data["rating_explanation"],
        "full_content": response,
        "full_content_json": report_data,
    }

时序图

sequenceDiagram
    autonumber
    participant WF as Workflow: create_community_reports
    participant Storage as Storage
    participant Prep as prepare_community_contexts
    participant DFR as derive_from_rows
    participant Cache as Cache
    participant LLM as LLM Service
    
    WF->>Storage: 加载 communities, entities, relationships, text_units
    Storage-->>WF: 所有必要的 DataFrames
    
    WF->>Prep: prepare_community_contexts(communities, entities, ...)
    Prep->>Prep: 为每个社区构建上下文字符串(实体列表、关系列表、文本片段)
    Prep-->>WF: community_contexts DataFrame
    
    WF->>DFR: derive_from_rows(community_contexts, generate_report, threads=25)
    
    loop 每个社区(并发)
        DFR->>DFR: _generate_single_report(community_id, context, cache, config)
        
        DFR->>DFR: 构建社区报告提示词
        DFR->>Cache: 检查缓存 (hash(community_id + context + prompt))
        
        alt 缓存命中
            Cache-->>DFR: 缓存的报告
        else 缓存未命中
            DFR->>LLM: ChatModel.call(prompt)
            LLM-->>DFR: LLM 响应 (JSON格式报告)
            DFR->>DFR: 解析 JSON(title, summary, findings, rating)
            DFR->>Cache: 存储报告到缓存
        end
        
        DFR-->>WF: 报告字典
    end
    
    WF->>WF: 转换为 DataFrame
    WF->>Storage: write_table_to_storage(community_reports, "community_reports")

边界与异常

  • 上下文截断:如果社区上下文超过 token 限制,会自动截断
  • 报告质量:报告质量依赖 LLM 能力,可能需要调优提示词
  • 格式错误:如果 LLM 返回非 JSON 格式,使用 json-repair 修复

6. generate_text_embeddings 工作流

功能说明

计算实体描述、关系描述、文本单元的向量嵌入,并存储到向量数据库。

请求结构

输入(从配置中读取):

字段 类型 必填 默认 说明
embed_text.model_id str default 嵌入模型 ID
embed_text.batch_size int 100 批量嵌入大小
embed_text.concurrent_requests int 25 并发请求数
vector_store.{embedding_name}.* dict {} 向量存储配置

响应结构

输出 1:在原表中添加嵌入列(如 entities.parquetdescription_embedding 列)

输出 2:写入向量存储(如 Azure AI Search、LanceDB)

字段 类型 说明
id str 文档 ID
embedding list[float] 嵌入向量(如 1536 维)
title str 文档标题(可选)
text str 原始文本(可选)

入口函数与核心代码

async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    # 为不同类型数据生成嵌入
    # 1. 实体描述嵌入
    entities = await load_table_from_storage("entities", context.output_storage)
    entities = await embed_text(
        input=entities,
        callbacks=context.callbacks,
        cache=context.cache,
        embed_column="description",
        strategy=config.embed_text.strategy,
        embedding_name="entity_description_embedding",
    )
    await write_table_to_storage(entities, "entities", context.output_storage)
    
    # 2. 关系描述嵌入(类似)
    # 3. 文本单元嵌入(类似)
    
    return WorkflowFunctionOutput(result=None)


async def embed_text(
    input: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache,
    embed_column: str,
    strategy: dict,
    embedding_name: str,
):
    # 1. 创建嵌入模型
    model = ModelManager().get_or_create_embedding_model(
        name=embedding_name,
        model_type=strategy["model_type"],
        config=strategy["model_config"],
    )
    
    # 2. 批量嵌入
    batch_size = strategy.get("batch_size", 100)
    
    async def embed_batch(texts: list[str]) -> list[list[float]]:
        # 检查缓存
        cached_embeddings = []
        uncached_texts = []
        
        for text in texts:
            cache_key = gen_sha512_hash({"text": text, "model": model.name})
            cached = await cache.get(cache_key)
            if cached:
                cached_embeddings.append(cached)
            else:
                uncached_texts.append(text)
        
        # 嵌入未缓存的文本
        if uncached_texts:
            new_embeddings = await model.embed(uncached_texts)
            # 存储到缓存
            for text, embedding in zip(uncached_texts, new_embeddings):
                cache_key = gen_sha512_hash({"text": text, "model": model.name})
                await cache.set(cache_key, embedding)
        
        return cached_embeddings + new_embeddings
    
    # 3. 分批处理
    embeddings = []
    for i in range(0, len(input), batch_size):
        batch = input[embed_column].iloc[i:i+batch_size].tolist()
        batch_embeddings = await embed_batch(batch)
        embeddings.extend(batch_embeddings)
    
    # 4. 添加嵌入列
    input[f"{embed_column}_embedding"] = embeddings
    
    # 5. 写入向量存储(如果配置)
    if strategy.get("vector_store"):
        vector_store = create_vector_store(strategy["vector_store"], embedding_name)
        await vector_store.load_documents(
            documents=[
                {"id": row["id"], "embedding": row[f"{embed_column}_embedding"], "title": row.get("title")}
                for _, row in input.iterrows()
            ]
        )
    
    return input

时序图

sequenceDiagram
    autonumber
    participant WF as Workflow: generate_text_embeddings
    participant Storage as Storage
    participant EmbedOp as embed_text Operation
    participant Cache as Cache
    participant Embedding as Embedding Model
    participant Vector as Vector Store
    
    WF->>Storage: load_table_from_storage("entities")
    Storage-->>WF: entities DataFrame
    
    WF->>EmbedOp: embed_text(entities, "description", ...)
    
    loop 分批处理(batch_size=100
        EmbedOp->>EmbedOp: 提取当前批次文本
        
        loop 每个文本
            EmbedOp->>Cache: 检查嵌入缓存
            alt 缓存命中
                Cache-->>EmbedOp: 缓存的嵌入向量
            else 缓存未命中
                Note right of EmbedOp: 累积未缓存文本
            end
        end
        
        alt 有未缓存文本
            EmbedOp->>Embedding: embed_batch(uncached_texts)
            Embedding-->>EmbedOp: 嵌入向量列表
            
            loop 每个新嵌入
                EmbedOp->>Cache: 存储嵌入到缓存
            end
        end
        
        EmbedOp->>EmbedOp: 合并缓存和新嵌入
    end
    
    EmbedOp->>EmbedOp: 添加 embedding 列到 DataFrame
    
    alt 配置了向量存储
        EmbedOp->>Vector: load_documents(id, embedding, title)
        Vector-->>EmbedOp: 写入完成
    end
    
    EmbedOp-->>WF: 带嵌入列的 DataFrame
    
    WF->>Storage: write_table_to_storage(entities, "entities")

关键数据结构 UML 图

classDiagram
    class PipelineRunContext {
        +PipelineStorage input_storage
        +PipelineStorage output_storage
        +PipelineStorage previous_storage
        +PipelineCache cache
        +WorkflowCallbacks callbacks
        +dict state
    }
    
    class Pipeline {
        +list~tuple~ workflows
        +names() list~str~
        +remove(name: str)
    }
    
    class WorkflowFunctionOutput {
        +Any result
        +list~Exception~ errors
    }
    
    class PipelineRunResult {
        +str workflow
        +Any result
        +list~Exception~ errors
        +float total_runtime
        +dict stats
    }
    
    class Document {
        +str id
        +str text
        +dict metadata
    }
    
    class TextUnit {
        +str id
        +str text
        +int n_tokens
        +list~str~ document_ids
    }
    
    class Entity {
        +str id
        +str title
        +str type
        +list~str~ description
        +list~float~ description_embedding
        +list~str~ text_unit_ids
        +int frequency
    }
    
    class Relationship {
        +str id
        +str source
        +str target
        +list~str~ description
        +float weight
        +list~str~ text_unit_ids
    }
    
    class Community {
        +str id
        +int community
        +int level
        +str title
        +int parent
        +list~int~ children
        +list~str~ entity_ids
        +list~str~ relationship_ids
        +list~str~ text_unit_ids
        +int size
    }
    
    class CommunityReport {
        +str id
        +int community
        +int level
        +str title
        +str summary
        +str full_content
        +dict full_content_json
        +float rank
        +list~dict~ findings
    }
    
    Pipeline --> WorkflowFunctionOutput : returns
    Pipeline --> PipelineRunContext : uses
    PipelineRunContext --> PipelineRunResult : produces
    
    Document --> TextUnit : splits into
    TextUnit --> Entity : extracts
    TextUnit --> Relationship : extracts
    Entity --> Community : grouped into
    Relationship --> Community : grouped into
    Community --> CommunityReport : generates

数据结构说明

PipelineRunContext

  • 工作流执行的上下文对象
  • 包含存储、缓存、回调的引用
  • 工作流间共享状态通过 state 字典传递

Entity 字段详解

  • title:实体名称,作为唯一标识
  • type:实体类型(如 organization、person、geo、event)
  • description:描述列表(来自多个文本块的描述合并)
  • description_embedding:描述的嵌入向量
  • text_unit_ids:提及该实体的所有文本块 ID
  • frequency:实体在文本中出现的频率

Relationship 字段详解

  • source/target:源和目标实体的名称(非 ID)
  • weight:关系权重(累加),表示关系强度
  • description:关系描述列表

Community 字段详解

  • level:层级级别(0=最细粒度,数字越大越粗粒度)
  • parent:父社区 ID(-1 表示根社区)
  • children:子社区 ID 列表(构建层级树)

性能优化要点

1. 并发控制

LLM 调用并发

  • 通过 concurrent_requests 参数控制(默认 25)
  • 使用 asyncio.Semaphore 限流
  • 避免触发 API 速率限制

批量嵌入

  • batch_size 参数控制批量大小(默认 100)
  • 减少 API 调用次数
  • 平衡内存占用和延迟

2. 缓存策略

LLM 调用缓存

  • 缓存键:hash(text + prompt + model_params)
  • 存储:文件缓存(JSON 或 Parquet)
  • 命中率:通常 > 60%(重复文本多)

嵌入缓存

  • 缓存键:hash(text + model_name)
  • 批量预加载:启动时加载缓存索引

3. 内存管理

分批处理

  • 文本分块:按文档分批
  • 实体抽取:按文本块分批
  • 嵌入计算:按 batch_size 分批

DataFrame 优化

  • 使用 category 类型存储重复字符串
  • 及时释放中间 DataFrame

4. I/O 优化

Parquet 写入

  • 使用压缩(默认 snappy)
  • 批量写入,避免频繁 I/O

Blob Storage

  • 启用 Azure SDK 的并发上传
  • 使用异步 I/O(aiofiles)

异常处理与容错

工作流级错误

独立错误边界

  • 每个工作流独立 try-catch
  • 工作流失败不中断后续工作流
  • 错误聚合到 PipelineRunResult.errors

示例代码

async def run_pipeline(...):
    for workflow_name, workflow_func in pipeline.workflows:
        try:
            result = await workflow_func(config, context)
            yield PipelineRunResult(workflow=workflow_name, result=result, errors=[])
        except Exception as e:
            logger.error(f"Workflow {workflow_name} failed: {e}")
            yield PipelineRunResult(workflow=workflow_name, result=None, errors=[e])

操作级错误

LLM 调用失败

  • 自动重试:指数退避,最多 10 次
  • 超时:180 秒(可配置)
  • 降级:记录错误,跳过当前文本块

文件读写失败

  • Blob Storage:依赖 Azure SDK 重试
  • 本地文件:不重试,直接抛出异常

最佳实践

1. 配置调优

文本分块

  • chunks.size:300-600 tokens(取决于文档类型)
  • chunks.overlap:100-150 tokens(保证语义连续性)

实体抽取

  • extract_graph.entity_types:根据领域定制(如金融领域:["company", "person", "product", "regulation"]
  • extract_graph.max_gleanings:1-2 次(多次迭代抽取)

社区发现

  • cluster_graph.max_cluster_size:10-50(平衡社区粒度)
  • cluster_graph.use_lcc:True(忽略孤立节点)

2. 性能优化

缓存启用

  • 生产环境:使用 file 类型缓存
  • 开发环境:可用 memory 类型缓存

并发调整

  • OpenAI API:concurrent_requests=25
  • Azure OpenAI:根据 TPM 限制调整

3. 增量更新

何时使用

  • 新增文档数量 < 10% 已有文档
  • 不涉及重大配置变更

注意事项

  • 增量更新会合并新旧实体,可能产生冗余
  • 定期全量重建(如每月)

常见问题

Q1: 索引构建时间过长

原因

  • LLM 调用延迟高
  • 文档数量大

解决方案

  • 启用缓存
  • 增加 concurrent_requests
  • 使用 fast 索引模式(跳过 LLM 抽取,使用 NLP)

Q2: 内存不足

原因

  • 单个文档过大
  • 并发度过高

解决方案

  • 减少 concurrent_requests
  • 增加 chunks.size(减少文本块数量)
  • 分批处理文档

Q3: 社区质量不佳

原因

  • max_cluster_size 设置不当
  • 图连通性差

解决方案

  • 调整 max_cluster_size(尝试 10-50)
  • 启用 use_lcc=False(包含孤立节点)
  • 检查关系抽取质量

工作流执行顺序总结

Standard 模式

  1. load_input_documents:加载文档
  2. create_base_text_units:文本分块
  3. create_final_documents:最终化文档元数据
  4. extract_graph:LLM 实体关系抽取
  5. finalize_graph:图合并与去重
  6. extract_covariates:抽取协变量(claims)
  7. create_communities:社区发现
  8. create_final_text_units:最终化文本块(添加实体关联)
  9. create_community_reports:生成社区报告
  10. generate_text_embeddings:计算嵌入

Fast 模式

  1. load_input_documents
  2. create_base_text_units
  3. create_final_documents
  4. extract_graph_nlp:NLP 实体关系抽取(无 LLM)
  5. prune_graph:图剪枝
  6. finalize_graph
  7. create_communities
  8. create_final_text_units
  9. create_community_reports_text:基于文本的社区报告(无 LLM)
  10. generate_text_embeddings

Update 模式

在 Standard/Fast 的基础上,追加:

  1. update_final_documents:合并新旧文档
  2. update_entities_relationships:合并新旧实体关系
  3. update_text_units:合并新旧文本块
  4. update_covariates:合并协变量
  5. update_communities:重新计算社区
  6. update_community_reports:更新社区报告
  7. update_text_embeddings:更新嵌入
  8. update_clean_state:清理临时状态