recall: normalize entity weights + relation target

This commit is contained in:
2026-02-03 01:13:57 +08:00
parent 1128d1494e
commit b0ed876cb0

View File

@@ -1,6 +1,7 @@
// Story Summary - Recall Engine // Story Summary - Recall Engine
// L1 chunk + L2 event 召回 // L1 chunk + L2 event 召回
// - 全量向量打分 // - 全量向量打分
// - 实体权重归一化分配
// - 指数衰减加权 Query Embedding // - 指数衰减加权 Query Embedding
// - L0 floor 加权 // - L0 floor 加权
// - RRF 混合检索(向量 + 文本) // - RRF 混合检索(向量 + 文本)
@@ -193,6 +194,19 @@ function cleanForRecall(text) {
return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim(); return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim();
} }
// ═══════════════════════════════════════════════════════════════════════════
// 从谓词提取关系对象
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Extracts the relation target (the "other party") from a relation predicate.
 *
 * Two predicate shapes are recognized:
 *   - `对X的看法` ("opinion of X")
 *   - `与X的关系` ("relationship with X")
 *
 * The predicate is trimmed before matching. The patterns are anchored at both
 * ends, so without the trim a predicate padded with whitespace or ending in a
 * newline would match nothing and its target would be lost.
 *
 * @param {*} p - Predicate; coerced with `String()`. Falsy values yield ''.
 * @returns {string} The trimmed target X, or '' when the predicate matches
 *   neither shape (or the captured target is whitespace only).
 */
function extractRelationTarget(p) {
    if (!p) return '';

    const predicate = String(p).trim();
    const patterns = [/^对(.+)的看法$/, /^与(.+)的关系$/];

    for (const pattern of patterns) {
        const m = predicate.match(pattern);
        // The first matching shape wins, even if its captured target trims to ''.
        if (m) return m[1].trim();
    }
    return '';
}
function buildExpDecayWeights(n, beta) { function buildExpDecayWeights(n, beta) {
const last = n - 1; const last = n - 1;
const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last))); const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last)));
@@ -296,19 +310,22 @@ function buildEntityLexicon(store, allEvents) {
function buildFactGraph(facts) { function buildFactGraph(facts) {
const graph = new Map(); const graph = new Map();
const { name1 } = getContext();
const userName = normalize(name1);
for (const f of facts || []) { for (const f of facts || []) {
if (f?.retracted) continue; if (f?.retracted) continue;
if (!isRelationFact(f)) continue; if (!isRelationFact(f)) continue;
const s = normalize(f?.s); const s = normalize(f?.s);
const o = normalize(f?.o); const target = normalize(extractRelationTarget(f?.p));
if (!s || !o) continue; if (!s || !target) continue;
if (s === userName || target === userName) continue;
if (!graph.has(s)) graph.set(s, new Set()); if (!graph.has(s)) graph.set(s, new Set());
if (!graph.has(o)) graph.set(o, new Set()); if (!graph.has(target)) graph.set(target, new Set());
graph.get(s).add(o); graph.get(s).add(target);
graph.get(o).add(s); graph.get(target).add(s);
} }
return graph; return graph;
@@ -347,6 +364,23 @@ function expandByFacts(presentEntities, facts, maxDepth = 2) {
.map(([term]) => term); .map(([term]) => term);
} }
// ═══════════════════════════════════════════════════════════════════════════
// 实体权重归一化(用于加分分配)
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Normalizes entity weights so that they sum to 1, for distributing a bonus
 * among entities in proportion to their weight.
 *
 * Weights that are not finite numbers (NaN, ±Infinity, non-numeric values)
 * are counted as 0. Otherwise a single bad entry would make the total NaN and
 * turn every entity's share into NaN.
 *
 * @param {Map<*, number>} queryEntityWeights - Entity -> raw weight.
 * @returns {Map<*, number>} Entity -> weight / total, in the input's iteration
 *   order. An empty Map is returned when the input is missing or empty, or
 *   when the usable total is not a positive finite number.
 */
function normalizeEntityWeights(queryEntityWeights) {
    if (!queryEntityWeights?.size) return new Map();

    const usable = (w) => (Number.isFinite(w) ? w : 0);

    let total = 0;
    for (const weight of queryEntityWeights.values()) {
        total += usable(weight);
    }
    if (total <= 0 || !Number.isFinite(total)) return new Map();

    const normalized = new Map();
    for (const [entity, weight] of queryEntityWeights) {
        normalized.set(entity, usable(weight) / total);
    }
    return normalized;
}
function stripFloorTag(s) { function stripFloorTag(s) {
return String(s || '').replace(/\s*\(#\d+(?:-\d+)?\)\s*$/, '').trim(); return String(s || '').replace(/\s*\(#\d+(?:-\d+)?\)\s*$/, '').trim();
} }
@@ -543,7 +577,7 @@ async function searchChunks(queryVector, vectorConfig, l0FloorBonus = new Map(),
// L2 Events 检索RRF 混合 + MMR 后置) // L2 Events 检索RRF 混合 + MMR 后置)
// ═══════════════════════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════════════════════
async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus = new Map()) { async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus = new Map()) {
const { chatId } = getContext(); const { chatId } = getContext();
if (!chatId || !queryVector?.length) return []; if (!chatId || !queryVector?.length) return [];
@@ -567,7 +601,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
// 向量路检索(只保留 L0 加权) // 向量路检索(只保留 L0 加权)
// ═══════════════════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════════════════
const ENTITY_BONUS_FACTOR = 0.10; const ENTITY_BONUS_POOL = 0.10;
const scored = (allEvents || []).map((event, idx) => { const scored = (allEvents || []).map((event, idx) => {
const v = vectorMap.get(event.id); const v = vectorMap.get(event.id);
@@ -589,12 +623,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
const participants = (event.participants || []).map(p => normalize(p)); const participants = (event.participants || []).map(p => normalize(p));
let maxEntityWeight = 0; let maxEntityWeight = 0;
for (const p of participants) { for (const p of participants) {
const w = queryEntityWeights.get(p) || 0; const w = normalizedEntityWeights.get(p) || 0;
if (w > maxEntityWeight) { if (w > maxEntityWeight) {
maxEntityWeight = w; maxEntityWeight = w;
} }
} }
const entityBonus = ENTITY_BONUS_FACTOR * maxEntityWeight; const entityBonus = ENTITY_BONUS_POOL * maxEntityWeight;
bonus += entityBonus; bonus += entityBonus;
return { return {
@@ -605,10 +639,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
finalScore: sim + bonus, finalScore: sim + bonus,
vector: v, vector: v,
_entityBonus: entityBonus, _entityBonus: entityBonus,
_hasPresent: maxEntityWeight > 0,
}; };
}); });
const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus])); const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus]));
const hasPresentById = new Map(scored.map(s => [s._id, s._hasPresent]));
const preFilterDistribution = { const preFilterDistribution = {
total: scored.length, total: scored.length,
@@ -654,7 +690,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
const results = mmrOutput.map(x => ({ const results = mmrOutput.map(x => ({
event: x.event, event: x.event,
similarity: x.rrf, similarity: x.rrf,
_recallType: x.type === 'HYBRID' ? 'DIRECT' : 'SIMILAR', _recallType: hasPresentById.get(x.event?.id) ? 'DIRECT' : 'SIMILAR',
_recallReason: x.type, _recallReason: x.type,
_rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf }, _rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf },
_entityBonus: entityBonusById.get(x.event?.id) || 0, _entityBonus: entityBonusById.get(x.event?.id) || 0,
@@ -687,7 +723,7 @@ function formatRecallLog({
chunkResults, chunkResults,
eventResults, eventResults,
allEvents, allEvents,
queryEntityWeights = new Map(), normalizedEntityWeights = new Map(),
causalEvents = [], causalEvents = [],
chunkPreFilterStats = null, chunkPreFilterStats = null,
l0Results = [], l0Results = [],
@@ -724,8 +760,8 @@ function formatRecallLog({
lines.push('\u2502 【提取实体】 \u2502'); lines.push('\u2502 【提取实体】 \u2502');
lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518'); lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518');
if (queryEntityWeights?.size) { if (normalizedEntityWeights?.size) {
const sorted = Array.from(queryEntityWeights.entries()) const sorted = Array.from(normalizedEntityWeights.entries())
.sort((a, b) => b[1] - a[1]) .sort((a, b) => b[1] - a[1])
.slice(0, 8); .slice(0, 8);
const formatted = sorted const formatted = sorted
@@ -739,6 +775,20 @@ function formatRecallLog({
lines.push(` 扩散: ${expandedTerms.join('、')}`); lines.push(` 扩散: ${expandedTerms.join('、')}`);
} }
lines.push('');
lines.push(' 实体归一化(用于加分):');
if (normalizedEntityWeights?.size) {
const sorted = Array.from(normalizedEntityWeights.entries())
.sort((a, b) => b[1] - a[1])
.slice(0, 8);
const formatted = sorted
.map(([e, w]) => `${e}(${(w * 100).toFixed(0)}%)`)
.join(' | ');
lines.push(` ${formatted}`);
} else {
lines.push(' (无)');
}
lines.push(''); lines.push('');
lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510'); lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510');
lines.push('\u2502 【召回统计】 \u2502'); lines.push('\u2502 【召回统计】 \u2502');
@@ -834,6 +884,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
const queryEntities = Array.from(queryEntityWeights.keys()); const queryEntities = Array.from(queryEntityWeights.keys());
const facts = getFacts(store); const facts = getFacts(store);
const expandedTerms = expandByFacts(queryEntities, facts, 2); const expandedTerms = expandByFacts(queryEntities, facts, 2);
const normalizedEntityWeights = normalizeEntityWeights(queryEntityWeights);
// 构建文本查询串:最后一条消息 + 实体 + 关键词 // 构建文本查询串:最后一条消息 + 实体 + 关键词
const lastSeg = segments[segments.length - 1] || ''; const lastSeg = segments[segments.length - 1] || '';
@@ -859,7 +910,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
const [chunkResults, eventResults] = await Promise.all([ const [chunkResults, eventResults] = await Promise.all([
searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor), searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor),
searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus), searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus),
]); ]);
const chunkPreFilterStats = chunkResults._preFilterStats || null; const chunkPreFilterStats = chunkResults._preFilterStats || null;
@@ -897,7 +948,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
chunkResults: mergedChunks, chunkResults: mergedChunks,
eventResults, eventResults,
allEvents, allEvents,
queryEntityWeights, normalizedEntityWeights,
causalEvents: causalEventsTruncated, causalEvents: causalEventsTruncated,
chunkPreFilterStats, chunkPreFilterStats,
l0Results, l0Results,