recall: normalize entity weights + relation target
@@ -1,6 +1,7 @@
 // Story Summary - Recall Engine
 // L1 chunk + L2 event recall
 // - Vector scoring over the full set
+// - Normalized entity-weight allocation
 // - Exponential-decay-weighted query embedding
 // - L0 floor weighting
 // - RRF hybrid retrieval (vector + text)
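The exponential-decay weighting listed above is only partially visible in this diff (buildExpDecayWeights appears below, truncated at a hunk boundary). As a rough sketch of how such weights could be folded into a single query embedding — the combination step, the normalization, and the segmentVectors name are assumptions, not code from this commit:

function buildWeightedQueryVector(segmentVectors, beta = 0.5) {
    // Same shape as buildExpDecayWeights below: the newest segment gets weight 1,
    // older segments decay as exp(beta * (i - last)).
    if (!segmentVectors.length) return [];
    const n = segmentVectors.length;
    const last = n - 1;
    const weights = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last)));
    const total = weights.reduce((a, b) => a + b, 0);

    const dim = segmentVectors[0].length;
    const out = new Array(dim).fill(0);
    segmentVectors.forEach((vec, i) => {
        for (let d = 0; d < dim; d++) out[d] += (weights[i] / total) * vec[d];
    });
    return out; // weighted average of the segment embeddings
}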
@@ -193,6 +194,19 @@ function cleanForRecall(text) {
     return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim();
 }
 
+// ═══════════════════════════════════════════════════════════════════════════
+// Extract the relation target from a predicate
+// ═══════════════════════════════════════════════════════════════════════════
+
+function extractRelationTarget(p) {
+    if (!p) return '';
+    let m = String(p).match(/^对(.+)的看法$/);
+    if (m) return m[1].trim();
+    m = String(p).match(/^与(.+)的关系$/);
+    if (m) return m[1].trim();
+    return '';
+}
+
 function buildExpDecayWeights(n, beta) {
     const last = n - 1;
     const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last)));
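For reference, the behavior of the new extractRelationTarget follows directly from the two regexes above (the names below are made up for illustration):

extractRelationTarget('对莉莉的看法');  // → '莉莉'   ("opinion of 莉莉")
extractRelationTarget('与阿强的关系');  // → '阿强'   ("relationship with 阿强")
extractRelationTarget('喜欢的食物');    // → ''       (predicate is not a relation)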
@@ -296,19 +310,22 @@ function buildEntityLexicon(store, allEvents) {
 
 function buildFactGraph(facts) {
     const graph = new Map();
+    const { name1 } = getContext();
+    const userName = normalize(name1);
 
     for (const f of facts || []) {
         if (f?.retracted) continue;
         if (!isRelationFact(f)) continue;
 
         const s = normalize(f?.s);
-        const o = normalize(f?.o);
-        if (!s || !o) continue;
+        const target = normalize(extractRelationTarget(f?.p));
+        if (!s || !target) continue;
+        if (s === userName || target === userName) continue;
 
         if (!graph.has(s)) graph.set(s, new Set());
-        if (!graph.has(o)) graph.set(o, new Set());
-        graph.get(s).add(o);
-        graph.get(o).add(s);
+        if (!graph.has(target)) graph.set(target, new Set());
+        graph.get(s).add(target);
+        graph.get(target).add(s);
     }
 
     return graph;
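The effect of the change is that the graph now links a subject to the person named in the predicate rather than to the fact's object string. A tiny example with hypothetical facts (field shapes follow the accesses above; assuming isRelationFact accepts them and normalize leaves the names unchanged):

const facts = [
    { s: '莉莉', p: '与阿强的关系', o: '青梅竹马' },
    { s: '莉莉', p: '对店长的看法', o: '有点怕' },
];
buildFactGraph(facts);
// Before this commit: 莉莉 ↔ 青梅竹马, 莉莉 ↔ 有点怕  (object strings become graph nodes)
// After this commit:  莉莉 ↔ 阿强, 莉莉 ↔ 店长        (relation targets become graph nodes)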
@@ -347,6 +364,23 @@ function expandByFacts(presentEntities, facts, maxDepth = 2) {
         .map(([term]) => term);
 }
 
+// ═══════════════════════════════════════════════════════════════════════════
+// Entity-weight normalization (for bonus allocation)
+// ═══════════════════════════════════════════════════════════════════════════
+
+function normalizeEntityWeights(queryEntityWeights) {
+    if (!queryEntityWeights?.size) return new Map();
+
+    const total = Array.from(queryEntityWeights.values()).reduce((a, b) => a + b, 0);
+    if (total <= 0) return new Map();
+
+    const normalized = new Map();
+    for (const [entity, weight] of queryEntityWeights) {
+        normalized.set(entity, weight / total);
+    }
+    return normalized;
+}
+
 function stripFloorTag(s) {
     return String(s || '').replace(/\s*\(#\d+(?:-\d+)?\)\s*$/, '').trim();
 }
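A quick worked example of the new helper (entity names are illustrative):

normalizeEntityWeights(new Map([['莉莉', 3], ['阿强', 1]]));
// → Map { '莉莉' => 0.75, '阿强' => 0.25 }   (weights now sum to 1)
normalizeEntityWeights(new Map());              // → Map {}  (empty input)
normalizeEntityWeights(new Map([['x', 0]]));    // → Map {}  (non-positive total)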
@@ -543,7 +577,7 @@ async function searchChunks(queryVector, vectorConfig, l0FloorBonus = new Map(),
 // L2 event retrieval (RRF hybrid + MMR post-processing)
 // ═══════════════════════════════════════════════════════════════════════════
 
-async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus = new Map()) {
+async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus = new Map()) {
     const { chatId } = getContext();
     if (!chatId || !queryVector?.length) return [];
 
@@ -567,7 +601,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
     // Vector-path retrieval (only L0 weighting retained)
     // ═══════════════════════════════════════════════════════════════════════
 
-    const ENTITY_BONUS_FACTOR = 0.10;
+    const ENTITY_BONUS_POOL = 0.10;
 
     const scored = (allEvents || []).map((event, idx) => {
         const v = vectorMap.get(event.id);
@@ -589,12 +623,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
         const participants = (event.participants || []).map(p => normalize(p));
         let maxEntityWeight = 0;
         for (const p of participants) {
-            const w = queryEntityWeights.get(p) || 0;
+            const w = normalizedEntityWeights.get(p) || 0;
             if (w > maxEntityWeight) {
                 maxEntityWeight = w;
             }
         }
-        const entityBonus = ENTITY_BONUS_FACTOR * maxEntityWeight;
+        const entityBonus = ENTITY_BONUS_POOL * maxEntityWeight;
         bonus += entityBonus;
 
         return {
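The rename from ENTITY_BONUS_FACTOR to ENTITY_BONUS_POOL reads naturally alongside the switch to normalized weights: since the weights now sum to 1, the constant acts as a total bonus budget rather than a per-entity multiplier. With the hypothetical 0.75 / 0.25 split from the earlier example:

// Event whose participants match only the weaker entity:
//   entityBonus = ENTITY_BONUS_POOL * 0.25 = 0.10 * 0.25 = 0.025
// Event matching the strongest entity:
//   entityBonus = 0.10 * 0.75 = 0.075   (never more than the 0.10 pool, since weights ≤ 1)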
@@ -605,10 +639,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
             finalScore: sim + bonus,
             vector: v,
             _entityBonus: entityBonus,
+            _hasPresent: maxEntityWeight > 0,
         };
     });
 
     const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus]));
+    const hasPresentById = new Map(scored.map(s => [s._id, s._hasPresent]));
 
     const preFilterDistribution = {
         total: scored.length,
@@ -654,7 +690,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
     const results = mmrOutput.map(x => ({
         event: x.event,
         similarity: x.rrf,
-        _recallType: x.type === 'HYBRID' ? 'DIRECT' : 'SIMILAR',
+        _recallType: hasPresentById.get(x.event?.id) ? 'DIRECT' : 'SIMILAR',
        _recallReason: x.type,
        _rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf },
        _entityBonus: entityBonusById.get(x.event?.id) || 0,
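The _rrfDetail fields (vRank, tRank, rrf) point at a reciprocal-rank-fusion step that happens outside this hunk; the following is only a generic sketch of the standard formula, with the smoothing constant k = 60 assumed rather than taken from the code:

// Standard RRF over a vector rank and a text rank (1-based; 0/undefined = not in that list).
function rrfScore(vRank, tRank, k = 60) {
    const fromVector = vRank ? 1 / (k + vRank) : 0;
    const fromText = tRank ? 1 / (k + tRank) : 0;
    return fromVector + fromText;
}
// e.g. ranked 1st by vector search and 3rd by text search:
// rrfScore(1, 3) ≈ 1/61 + 1/63 ≈ 0.0323

Note also that the DIRECT/SIMILAR label now keys off hasPresentById (whether any query entity appears among the event's participants) instead of the fusion type.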
@@ -687,7 +723,7 @@ function formatRecallLog({
     chunkResults,
     eventResults,
     allEvents,
-    queryEntityWeights = new Map(),
+    normalizedEntityWeights = new Map(),
     causalEvents = [],
     chunkPreFilterStats = null,
     l0Results = [],
@@ -724,8 +760,8 @@ function formatRecallLog({
     lines.push('\u2502 【提取实体】 \u2502');
     lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518');
 
-    if (queryEntityWeights?.size) {
-        const sorted = Array.from(queryEntityWeights.entries())
+    if (normalizedEntityWeights?.size) {
+        const sorted = Array.from(normalizedEntityWeights.entries())
            .sort((a, b) => b[1] - a[1])
            .slice(0, 8);
        const formatted = sorted
@@ -739,6 +775,20 @@ function formatRecallLog({
         lines.push(` 扩散: ${expandedTerms.join('、')}`);
     }
 
+    lines.push('');
+    lines.push(' 实体归一化(用于加分):');
+    if (normalizedEntityWeights?.size) {
+        const sorted = Array.from(normalizedEntityWeights.entries())
+            .sort((a, b) => b[1] - a[1])
+            .slice(0, 8);
+        const formatted = sorted
+            .map(([e, w]) => `${e}(${(w * 100).toFixed(0)}%)`)
+            .join(' | ');
+        lines.push(` ${formatted}`);
+    } else {
+        lines.push(' (无)');
+    }
+
     lines.push('');
     lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510');
     lines.push('\u2502 【召回统计】 \u2502');
@@ -834,6 +884,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
     const queryEntities = Array.from(queryEntityWeights.keys());
     const facts = getFacts(store);
     const expandedTerms = expandByFacts(queryEntities, facts, 2);
+    const normalizedEntityWeights = normalizeEntityWeights(queryEntityWeights);
 
     // Build the text query string: last message + entities + keywords
     const lastSeg = segments[segments.length - 1] || '';
@@ -859,7 +910,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
 
     const [chunkResults, eventResults] = await Promise.all([
         searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor),
-        searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus),
+        searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus),
     ]);
 
     const chunkPreFilterStats = chunkResults._preFilterStats || null;
@@ -897,7 +948,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
         chunkResults: mergedChunks,
         eventResults,
         allEvents,
-        queryEntityWeights,
+        normalizedEntityWeights,
         causalEvents: causalEventsTruncated,
         chunkPreFilterStats,
         l0Results,