From b0ed876cb0ad5703d0cc96d992104f72b882ad7e Mon Sep 17 00:00:00 2001 From: bielie Date: Tue, 3 Feb 2026 01:13:57 +0800 Subject: [PATCH] recall: normalize entity weights + relation target --- modules/story-summary/vector/recall.js | 81 +++++++++++++++++++++----- 1 file changed, 66 insertions(+), 15 deletions(-) diff --git a/modules/story-summary/vector/recall.js b/modules/story-summary/vector/recall.js index bf5279f..6602da6 100644 --- a/modules/story-summary/vector/recall.js +++ b/modules/story-summary/vector/recall.js @@ -1,6 +1,7 @@ // Story Summary - Recall Engine // L1 chunk + L2 event 召回 // - 全量向量打分 +// - 实体权重归一化分配 // - 指数衰减加权 Query Embedding // - L0 floor 加权 // - RRF 混合检索(向量 + 文本) @@ -193,6 +194,19 @@ function cleanForRecall(text) { return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim(); } +// ═══════════════════════════════════════════════════════════════════════════ +// 从谓词提取关系对象 +// ═══════════════════════════════════════════════════════════════════════════ + +function extractRelationTarget(p) { + if (!p) return ''; + let m = String(p).match(/^对(.+)的看法$/); + if (m) return m[1].trim(); + m = String(p).match(/^与(.+)的关系$/); + if (m) return m[1].trim(); + return ''; +} + function buildExpDecayWeights(n, beta) { const last = n - 1; const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last))); @@ -296,19 +310,22 @@ function buildEntityLexicon(store, allEvents) { function buildFactGraph(facts) { const graph = new Map(); + const { name1 } = getContext(); + const userName = normalize(name1); for (const f of facts || []) { if (f?.retracted) continue; if (!isRelationFact(f)) continue; const s = normalize(f?.s); - const o = normalize(f?.o); - if (!s || !o) continue; + const target = normalize(extractRelationTarget(f?.p)); + if (!s || !target) continue; + if (s === userName || target === userName) continue; if (!graph.has(s)) graph.set(s, new Set()); - if (!graph.has(o)) graph.set(o, new Set()); - graph.get(s).add(o); - graph.get(o).add(s); + if (!graph.has(target)) graph.set(target, new Set()); + graph.get(s).add(target); + graph.get(target).add(s); } return graph; @@ -347,6 +364,23 @@ function expandByFacts(presentEntities, facts, maxDepth = 2) { .map(([term]) => term); } +// ═══════════════════════════════════════════════════════════════════════════ +// 实体权重归一化(用于加分分配) +// ═══════════════════════════════════════════════════════════════════════════ + +function normalizeEntityWeights(queryEntityWeights) { + if (!queryEntityWeights?.size) return new Map(); + + const total = Array.from(queryEntityWeights.values()).reduce((a, b) => a + b, 0); + if (total <= 0) return new Map(); + + const normalized = new Map(); + for (const [entity, weight] of queryEntityWeights) { + normalized.set(entity, weight / total); + } + return normalized; +} + function stripFloorTag(s) { return String(s || '').replace(/\s*\(#\d+(?:-\d+)?\)\s*$/, '').trim(); } @@ -543,7 +577,7 @@ async function searchChunks(queryVector, vectorConfig, l0FloorBonus = new Map(), // L2 Events 检索(RRF 混合 + MMR 后置) // ═══════════════════════════════════════════════════════════════════════════ -async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus = new Map()) { +async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus = new Map()) { const { chatId } = getContext(); if (!chatId || !queryVector?.length) return []; @@ -567,7 +601,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo // 向量路检索(只保留 L0 加权) // ═══════════════════════════════════════════════════════════════════════ - const ENTITY_BONUS_FACTOR = 0.10; + const ENTITY_BONUS_POOL = 0.10; const scored = (allEvents || []).map((event, idx) => { const v = vectorMap.get(event.id); @@ -589,12 +623,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo const participants = (event.participants || []).map(p => normalize(p)); let maxEntityWeight = 0; for (const p of participants) { - const w = queryEntityWeights.get(p) || 0; + const w = normalizedEntityWeights.get(p) || 0; if (w > maxEntityWeight) { maxEntityWeight = w; } } - const entityBonus = ENTITY_BONUS_FACTOR * maxEntityWeight; + const entityBonus = ENTITY_BONUS_POOL * maxEntityWeight; bonus += entityBonus; return { @@ -605,10 +639,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo finalScore: sim + bonus, vector: v, _entityBonus: entityBonus, + _hasPresent: maxEntityWeight > 0, }; }); const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus])); + const hasPresentById = new Map(scored.map(s => [s._id, s._hasPresent])); const preFilterDistribution = { total: scored.length, @@ -654,7 +690,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo const results = mmrOutput.map(x => ({ event: x.event, similarity: x.rrf, - _recallType: x.type === 'HYBRID' ? 'DIRECT' : 'SIMILAR', + _recallType: hasPresentById.get(x.event?.id) ? 'DIRECT' : 'SIMILAR', _recallReason: x.type, _rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf }, _entityBonus: entityBonusById.get(x.event?.id) || 0, @@ -687,7 +723,7 @@ function formatRecallLog({ chunkResults, eventResults, allEvents, - queryEntityWeights = new Map(), + normalizedEntityWeights = new Map(), causalEvents = [], chunkPreFilterStats = null, l0Results = [], @@ -724,8 +760,8 @@ function formatRecallLog({ lines.push('\u2502 【提取实体】 \u2502'); lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518'); - if (queryEntityWeights?.size) { - const sorted = Array.from(queryEntityWeights.entries()) + if (normalizedEntityWeights?.size) { + const sorted = Array.from(normalizedEntityWeights.entries()) .sort((a, b) => b[1] - a[1]) .slice(0, 8); const formatted = sorted @@ -739,6 +775,20 @@ function formatRecallLog({ lines.push(` 扩散: ${expandedTerms.join('、')}`); } + lines.push(''); + lines.push(' 实体归一化(用于加分):'); + if (normalizedEntityWeights?.size) { + const sorted = Array.from(normalizedEntityWeights.entries()) + .sort((a, b) => b[1] - a[1]) + .slice(0, 8); + const formatted = sorted + .map(([e, w]) => `${e}(${(w * 100).toFixed(0)}%)`) + .join(' | '); + lines.push(` ${formatted}`); + } else { + lines.push(' (无)'); + } + lines.push(''); lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510'); lines.push('\u2502 【召回统计】 \u2502'); @@ -834,6 +884,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options = const queryEntities = Array.from(queryEntityWeights.keys()); const facts = getFacts(store); const expandedTerms = expandByFacts(queryEntities, facts, 2); + const normalizedEntityWeights = normalizeEntityWeights(queryEntityWeights); // 构建文本查询串:最后一条消息 + 实体 + 关键词 const lastSeg = segments[segments.length - 1] || ''; @@ -859,7 +910,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options = const [chunkResults, eventResults] = await Promise.all([ searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor), - searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus), + searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus), ]); const chunkPreFilterStats = chunkResults._preFilterStats || null; @@ -897,7 +948,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options = chunkResults: mergedChunks, eventResults, allEvents, - queryEntityWeights, + normalizedEntityWeights, causalEvents: causalEventsTruncated, chunkPreFilterStats, l0Results,