recall: normalize entity weights + relation target

This commit is contained in:
2026-02-03 01:13:57 +08:00
parent 1128d1494e
commit b0ed876cb0

View File

@@ -1,6 +1,7 @@
// Story Summary - Recall Engine
// L1 chunk + L2 event 召回
// - 全量向量打分
// - 实体权重归一化分配
// - 指数衰减加权 Query Embedding
// - L0 floor 加权
// - RRF 混合检索(向量 + 文本)
@@ -193,6 +194,19 @@ function cleanForRecall(text) {
return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim();
}
// ═══════════════════════════════════════════════════════════════════════════
// 从谓词提取关系对象
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Extracts the relation target embedded in a predicate string.
 *
 * Two predicate shapes are recognised (the whole string must match):
 *   - "对X的看法"  -> X
 *   - "与X的关系"  -> X
 *
 * @param {*} p - Predicate value; it is coerced with String() before matching.
 * @returns {string} The trimmed target, or '' when `p` is falsy or matches
 *   neither shape.
 */
function extractRelationTarget(p) {
  if (!p) return '';

  const predicate = String(p);
  // Checked in order; the first pattern that matches decides the result.
  const relationPatterns = [/^对(.+)的看法$/, /^与(.+)的关系$/];

  for (const pattern of relationPatterns) {
    const hit = pattern.exec(predicate);
    if (hit) {
      return hit[1].trim();
    }
  }

  return '';
}
function buildExpDecayWeights(n, beta) {
const last = n - 1;
const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last)));
@@ -296,19 +310,22 @@ function buildEntityLexicon(store, allEvents) {
function buildFactGraph(facts) {
const graph = new Map();
const { name1 } = getContext();
const userName = normalize(name1);
for (const f of facts || []) {
if (f?.retracted) continue;
if (!isRelationFact(f)) continue;
const s = normalize(f?.s);
const o = normalize(f?.o);
if (!s || !o) continue;
const target = normalize(extractRelationTarget(f?.p));
if (!s || !target) continue;
if (s === userName || target === userName) continue;
if (!graph.has(s)) graph.set(s, new Set());
if (!graph.has(o)) graph.set(o, new Set());
graph.get(s).add(o);
graph.get(o).add(s);
if (!graph.has(target)) graph.set(target, new Set());
graph.get(s).add(target);
graph.get(target).add(s);
}
return graph;
@@ -347,6 +364,23 @@ function expandByFacts(presentEntities, facts, maxDepth = 2) {
.map(([term]) => term);
}
// ═══════════════════════════════════════════════════════════════════════════
// 实体权重归一化(用于加分分配)
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Rescales entity weights so that they sum to 1.
 *
 * Entries whose weight is not a finite number (NaN, +/-Infinity, non-number)
 * are dropped before summing: one such value would otherwise turn the total
 * into NaN/Infinity and corrupt every normalized weight.
 *
 * @param {Map<*, number>} queryEntityWeights - Raw weight per entity.
 * @returns {Map<*, number>} A new Map of entity -> weight / total. Empty when
 *   the input is missing or empty, or when the usable total is not a
 *   positive finite number.
 */
function normalizeEntityWeights(queryEntityWeights) {
  if (!queryEntityWeights?.size) return new Map();

  // Keep only weights that can take part in a meaningful sum.
  const usable = Array.from(queryEntityWeights).filter(([, weight]) => Number.isFinite(weight));

  const total = usable.reduce((sum, [, weight]) => sum + weight, 0);
  // A sum of finite values can still overflow to Infinity, so re-check it.
  if (!Number.isFinite(total) || total <= 0) return new Map();

  return new Map(usable.map(([entity, weight]) => [entity, weight / total]));
}
/**
 * Removes a trailing floor tag such as " (#12)" or " (#3-5)" from a string.
 *
 * @param {*} s - Input value; falsy values are treated as the empty string,
 *   anything else is coerced with String().
 * @returns {string} The text without the trailing tag, trimmed on both ends.
 */
function stripFloorTag(s) {
  // Only a tag at the very end of the string (optionally padded by whitespace) is removed.
  const trailingFloorTag = /\s*\(#\d+(?:-\d+)?\)\s*$/;
  const text = s ? String(s) : '';
  const withoutTag = text.replace(trailingFloorTag, '');
  return withoutTag.trim();
}
@@ -543,7 +577,7 @@ async function searchChunks(queryVector, vectorConfig, l0FloorBonus = new Map(),
// L2 Events 检索(RRF 混合 + MMR 后置)
// ═══════════════════════════════════════════════════════════════════════════
async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus = new Map()) {
async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus = new Map()) {
const { chatId } = getContext();
if (!chatId || !queryVector?.length) return [];
@@ -567,7 +601,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
// 向量路检索(只保留 L0 加权)
// ═══════════════════════════════════════════════════════════════════════
const ENTITY_BONUS_FACTOR = 0.10;
const ENTITY_BONUS_POOL = 0.10;
const scored = (allEvents || []).map((event, idx) => {
const v = vectorMap.get(event.id);
@@ -589,12 +623,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
const participants = (event.participants || []).map(p => normalize(p));
let maxEntityWeight = 0;
for (const p of participants) {
const w = queryEntityWeights.get(p) || 0;
const w = normalizedEntityWeights.get(p) || 0;
if (w > maxEntityWeight) {
maxEntityWeight = w;
}
}
const entityBonus = ENTITY_BONUS_FACTOR * maxEntityWeight;
const entityBonus = ENTITY_BONUS_POOL * maxEntityWeight;
bonus += entityBonus;
return {
@@ -605,10 +639,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
finalScore: sim + bonus,
vector: v,
_entityBonus: entityBonus,
_hasPresent: maxEntityWeight > 0,
};
});
const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus]));
const hasPresentById = new Map(scored.map(s => [s._id, s._hasPresent]));
const preFilterDistribution = {
total: scored.length,
@@ -654,7 +690,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
const results = mmrOutput.map(x => ({
event: x.event,
similarity: x.rrf,
_recallType: x.type === 'HYBRID' ? 'DIRECT' : 'SIMILAR',
_recallType: hasPresentById.get(x.event?.id) ? 'DIRECT' : 'SIMILAR',
_recallReason: x.type,
_rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf },
_entityBonus: entityBonusById.get(x.event?.id) || 0,
@@ -687,7 +723,7 @@ function formatRecallLog({
chunkResults,
eventResults,
allEvents,
queryEntityWeights = new Map(),
normalizedEntityWeights = new Map(),
causalEvents = [],
chunkPreFilterStats = null,
l0Results = [],
@@ -724,8 +760,8 @@ function formatRecallLog({
lines.push('\u2502 【提取实体】 \u2502');
lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518');
if (queryEntityWeights?.size) {
const sorted = Array.from(queryEntityWeights.entries())
if (normalizedEntityWeights?.size) {
const sorted = Array.from(normalizedEntityWeights.entries())
.sort((a, b) => b[1] - a[1])
.slice(0, 8);
const formatted = sorted
@@ -739,6 +775,20 @@ function formatRecallLog({
lines.push(` 扩散: ${expandedTerms.join('、')}`);
}
lines.push('');
lines.push(' 实体归一化(用于加分):');
if (normalizedEntityWeights?.size) {
const sorted = Array.from(normalizedEntityWeights.entries())
.sort((a, b) => b[1] - a[1])
.slice(0, 8);
const formatted = sorted
.map(([e, w]) => `${e}(${(w * 100).toFixed(0)}%)`)
.join(' | ');
lines.push(` ${formatted}`);
} else {
lines.push(' (无)');
}
lines.push('');
lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510');
lines.push('\u2502 【召回统计】 \u2502');
@@ -834,6 +884,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
const queryEntities = Array.from(queryEntityWeights.keys());
const facts = getFacts(store);
const expandedTerms = expandByFacts(queryEntities, facts, 2);
const normalizedEntityWeights = normalizeEntityWeights(queryEntityWeights);
// 构建文本查询串:最后一条消息 + 实体 + 关键词
const lastSeg = segments[segments.length - 1] || '';
@@ -859,7 +910,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
const [chunkResults, eventResults] = await Promise.all([
searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor),
searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus),
searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus),
]);
const chunkPreFilterStats = chunkResults._preFilterStats || null;
@@ -897,7 +948,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
chunkResults: mergedChunks,
eventResults,
allEvents,
queryEntityWeights,
normalizedEntityWeights,
causalEvents: causalEventsTruncated,
chunkPreFilterStats,
l0Results,