recall: normalize entity weights + relation target
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
// Story Summary - Recall Engine
|
// Story Summary - Recall Engine
|
||||||
// L1 chunk + L2 event 召回
|
// L1 chunk + L2 event 召回
|
||||||
// - 全量向量打分
|
// - 全量向量打分
|
||||||
|
// - 实体权重归一化分配
|
||||||
// - 指数衰减加权 Query Embedding
|
// - 指数衰减加权 Query Embedding
|
||||||
// - L0 floor 加权
|
// - L0 floor 加权
|
||||||
// - RRF 混合检索(向量 + 文本)
|
// - RRF 混合检索(向量 + 文本)
|
||||||
@@ -193,6 +194,19 @@ function cleanForRecall(text) {
|
|||||||
return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim();
|
return filterText(text).replace(/\[tts:[^\]]*\]/gi, '').trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
// 从谓词提取关系对象
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
 * Extract the relation target (the "other party") from a relation predicate.
 *
 * Two predicate shapes are recognized, tried in this order:
 *   - `对<target>的看法`
 *   - `与<target>的关系`
 *
 * @param {*} p - Predicate value; coerced with String(). Any falsy value yields ''.
 * @returns {string} The captured target with surrounding whitespace removed,
 *   or '' when the predicate is falsy or matches neither shape.
 */
function extractRelationTarget(p) {
  if (!p) return '';

  // Trim before matching: both patterns are anchored with ^…$, so stray
  // leading/trailing whitespace on the predicate would otherwise defeat them.
  const predicate = String(p).trim();

  // Order matters only for parity with the original lookup sequence.
  const patterns = [/^对(.+)的看法$/, /^与(.+)的关系$/];
  for (const pattern of patterns) {
    const match = predicate.match(pattern);
    if (match) return match[1].trim();
  }

  return '';
}
|
||||||
|
|
||||||
function buildExpDecayWeights(n, beta) {
|
function buildExpDecayWeights(n, beta) {
|
||||||
const last = n - 1;
|
const last = n - 1;
|
||||||
const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last)));
|
const w = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - last)));
|
||||||
@@ -296,19 +310,22 @@ function buildEntityLexicon(store, allEvents) {
|
|||||||
|
|
||||||
function buildFactGraph(facts) {
|
function buildFactGraph(facts) {
|
||||||
const graph = new Map();
|
const graph = new Map();
|
||||||
|
const { name1 } = getContext();
|
||||||
|
const userName = normalize(name1);
|
||||||
|
|
||||||
for (const f of facts || []) {
|
for (const f of facts || []) {
|
||||||
if (f?.retracted) continue;
|
if (f?.retracted) continue;
|
||||||
if (!isRelationFact(f)) continue;
|
if (!isRelationFact(f)) continue;
|
||||||
|
|
||||||
const s = normalize(f?.s);
|
const s = normalize(f?.s);
|
||||||
const o = normalize(f?.o);
|
const target = normalize(extractRelationTarget(f?.p));
|
||||||
if (!s || !o) continue;
|
if (!s || !target) continue;
|
||||||
|
if (s === userName || target === userName) continue;
|
||||||
|
|
||||||
if (!graph.has(s)) graph.set(s, new Set());
|
if (!graph.has(s)) graph.set(s, new Set());
|
||||||
if (!graph.has(o)) graph.set(o, new Set());
|
if (!graph.has(target)) graph.set(target, new Set());
|
||||||
graph.get(s).add(o);
|
graph.get(s).add(target);
|
||||||
graph.get(o).add(s);
|
graph.get(target).add(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
return graph;
|
return graph;
|
||||||
@@ -347,6 +364,23 @@ function expandByFacts(presentEntities, facts, maxDepth = 2) {
|
|||||||
.map(([term]) => term);
|
.map(([term]) => term);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
// 实体权重归一化(用于加分分配)
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
 * Normalize entity weights so that they sum to 1.
 *
 * Weights that are not finite numbers, or that are negative, are treated as 0
 * so a single bad entry cannot turn every share into NaN or produce negative
 * shares. Such entries are kept in the result with a share of 0.
 *
 * @param {Map<string, number>|null|undefined} queryEntityWeights - Raw weight per entity.
 * @returns {Map<string, number>} New Map of entity -> weight / total. Empty when
 *   the input is missing/empty or when no positive weight exists.
 */
function normalizeEntityWeights(queryEntityWeights) {
  if (!queryEntityWeights?.size) return new Map();

  // Clamp invalid weights (NaN, ±Infinity, non-numbers, negatives) to 0.
  const sanitize = (weight) => (Number.isFinite(weight) && weight > 0 ? weight : 0);

  let total = 0;
  for (const weight of queryEntityWeights.values()) {
    total += sanitize(weight);
  }
  if (total <= 0) return new Map();

  const normalized = new Map();
  for (const [entity, weight] of queryEntityWeights) {
    normalized.set(entity, sanitize(weight) / total);
  }
  return normalized;
}
|
||||||
|
|
||||||
function stripFloorTag(s) {
|
function stripFloorTag(s) {
|
||||||
return String(s || '').replace(/\s*\(#\d+(?:-\d+)?\)\s*$/, '').trim();
|
return String(s || '').replace(/\s*\(#\d+(?:-\d+)?\)\s*$/, '').trim();
|
||||||
}
|
}
|
||||||
@@ -543,7 +577,7 @@ async function searchChunks(queryVector, vectorConfig, l0FloorBonus = new Map(),
|
|||||||
// L2 Events 检索(RRF 混合 + MMR 后置)
|
// L2 Events 检索(RRF 混合 + MMR 后置)
|
||||||
// ═══════════════════════════════════════════════════════════════════════════
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus = new Map()) {
|
async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus = new Map()) {
|
||||||
const { chatId } = getContext();
|
const { chatId } = getContext();
|
||||||
if (!chatId || !queryVector?.length) return [];
|
if (!chatId || !queryVector?.length) return [];
|
||||||
|
|
||||||
@@ -567,7 +601,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
|
|||||||
// 向量路检索(只保留 L0 加权)
|
// 向量路检索(只保留 L0 加权)
|
||||||
// ═══════════════════════════════════════════════════════════════════════
|
// ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
const ENTITY_BONUS_FACTOR = 0.10;
|
const ENTITY_BONUS_POOL = 0.10;
|
||||||
|
|
||||||
const scored = (allEvents || []).map((event, idx) => {
|
const scored = (allEvents || []).map((event, idx) => {
|
||||||
const v = vectorMap.get(event.id);
|
const v = vectorMap.get(event.id);
|
||||||
@@ -589,12 +623,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
|
|||||||
const participants = (event.participants || []).map(p => normalize(p));
|
const participants = (event.participants || []).map(p => normalize(p));
|
||||||
let maxEntityWeight = 0;
|
let maxEntityWeight = 0;
|
||||||
for (const p of participants) {
|
for (const p of participants) {
|
||||||
const w = queryEntityWeights.get(p) || 0;
|
const w = normalizedEntityWeights.get(p) || 0;
|
||||||
if (w > maxEntityWeight) {
|
if (w > maxEntityWeight) {
|
||||||
maxEntityWeight = w;
|
maxEntityWeight = w;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const entityBonus = ENTITY_BONUS_FACTOR * maxEntityWeight;
|
const entityBonus = ENTITY_BONUS_POOL * maxEntityWeight;
|
||||||
bonus += entityBonus;
|
bonus += entityBonus;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -605,10 +639,12 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
|
|||||||
finalScore: sim + bonus,
|
finalScore: sim + bonus,
|
||||||
vector: v,
|
vector: v,
|
||||||
_entityBonus: entityBonus,
|
_entityBonus: entityBonus,
|
||||||
|
_hasPresent: maxEntityWeight > 0,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus]));
|
const entityBonusById = new Map(scored.map(s => [s._id, s._entityBonus]));
|
||||||
|
const hasPresentById = new Map(scored.map(s => [s._id, s._hasPresent]));
|
||||||
|
|
||||||
const preFilterDistribution = {
|
const preFilterDistribution = {
|
||||||
total: scored.length,
|
total: scored.length,
|
||||||
@@ -654,7 +690,7 @@ async function searchEvents(queryVector, queryTextForSearch, allEvents, vectorCo
|
|||||||
const results = mmrOutput.map(x => ({
|
const results = mmrOutput.map(x => ({
|
||||||
event: x.event,
|
event: x.event,
|
||||||
similarity: x.rrf,
|
similarity: x.rrf,
|
||||||
_recallType: x.type === 'HYBRID' ? 'DIRECT' : 'SIMILAR',
|
_recallType: hasPresentById.get(x.event?.id) ? 'DIRECT' : 'SIMILAR',
|
||||||
_recallReason: x.type,
|
_recallReason: x.type,
|
||||||
_rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf },
|
_rrfDetail: { vRank: x.vRank, tRank: x.tRank, rrf: x.rrf },
|
||||||
_entityBonus: entityBonusById.get(x.event?.id) || 0,
|
_entityBonus: entityBonusById.get(x.event?.id) || 0,
|
||||||
@@ -687,7 +723,7 @@ function formatRecallLog({
|
|||||||
chunkResults,
|
chunkResults,
|
||||||
eventResults,
|
eventResults,
|
||||||
allEvents,
|
allEvents,
|
||||||
queryEntityWeights = new Map(),
|
normalizedEntityWeights = new Map(),
|
||||||
causalEvents = [],
|
causalEvents = [],
|
||||||
chunkPreFilterStats = null,
|
chunkPreFilterStats = null,
|
||||||
l0Results = [],
|
l0Results = [],
|
||||||
@@ -724,8 +760,8 @@ function formatRecallLog({
|
|||||||
lines.push('\u2502 【提取实体】 \u2502');
|
lines.push('\u2502 【提取实体】 \u2502');
|
||||||
lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518');
|
lines.push('\u2514' + '\u2500'.repeat(61) + '\u2518');
|
||||||
|
|
||||||
if (queryEntityWeights?.size) {
|
if (normalizedEntityWeights?.size) {
|
||||||
const sorted = Array.from(queryEntityWeights.entries())
|
const sorted = Array.from(normalizedEntityWeights.entries())
|
||||||
.sort((a, b) => b[1] - a[1])
|
.sort((a, b) => b[1] - a[1])
|
||||||
.slice(0, 8);
|
.slice(0, 8);
|
||||||
const formatted = sorted
|
const formatted = sorted
|
||||||
@@ -739,6 +775,20 @@ function formatRecallLog({
|
|||||||
lines.push(` 扩散: ${expandedTerms.join('、')}`);
|
lines.push(` 扩散: ${expandedTerms.join('、')}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lines.push('');
|
||||||
|
lines.push(' 实体归一化(用于加分):');
|
||||||
|
if (normalizedEntityWeights?.size) {
|
||||||
|
const sorted = Array.from(normalizedEntityWeights.entries())
|
||||||
|
.sort((a, b) => b[1] - a[1])
|
||||||
|
.slice(0, 8);
|
||||||
|
const formatted = sorted
|
||||||
|
.map(([e, w]) => `${e}(${(w * 100).toFixed(0)}%)`)
|
||||||
|
.join(' | ');
|
||||||
|
lines.push(` ${formatted}`);
|
||||||
|
} else {
|
||||||
|
lines.push(' (无)');
|
||||||
|
}
|
||||||
|
|
||||||
lines.push('');
|
lines.push('');
|
||||||
lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510');
|
lines.push('\u250c' + '\u2500'.repeat(61) + '\u2510');
|
||||||
lines.push('\u2502 【召回统计】 \u2502');
|
lines.push('\u2502 【召回统计】 \u2502');
|
||||||
@@ -834,6 +884,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
|
|||||||
const queryEntities = Array.from(queryEntityWeights.keys());
|
const queryEntities = Array.from(queryEntityWeights.keys());
|
||||||
const facts = getFacts(store);
|
const facts = getFacts(store);
|
||||||
const expandedTerms = expandByFacts(queryEntities, facts, 2);
|
const expandedTerms = expandByFacts(queryEntities, facts, 2);
|
||||||
|
const normalizedEntityWeights = normalizeEntityWeights(queryEntityWeights);
|
||||||
|
|
||||||
// 构建文本查询串:最后一条消息 + 实体 + 关键词
|
// 构建文本查询串:最后一条消息 + 实体 + 关键词
|
||||||
const lastSeg = segments[segments.length - 1] || '';
|
const lastSeg = segments[segments.length - 1] || '';
|
||||||
@@ -859,7 +910,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
|
|||||||
|
|
||||||
const [chunkResults, eventResults] = await Promise.all([
|
const [chunkResults, eventResults] = await Promise.all([
|
||||||
searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor),
|
searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor),
|
||||||
searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, queryEntityWeights, l0FloorBonus),
|
searchEvents(queryVector, queryTextForSearch, allEvents, vectorConfig, store, normalizedEntityWeights, l0FloorBonus),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const chunkPreFilterStats = chunkResults._preFilterStats || null;
|
const chunkPreFilterStats = chunkResults._preFilterStats || null;
|
||||||
@@ -897,7 +948,7 @@ export async function recallMemory(queryText, allEvents, vectorConfig, options =
|
|||||||
chunkResults: mergedChunks,
|
chunkResults: mergedChunks,
|
||||||
eventResults,
|
eventResults,
|
||||||
allEvents,
|
allEvents,
|
||||||
queryEntityWeights,
|
normalizedEntityWeights,
|
||||||
causalEvents: causalEventsTruncated,
|
causalEvents: causalEventsTruncated,
|
||||||
chunkPreFilterStats,
|
chunkPreFilterStats,
|
||||||
l0Results,
|
l0Results,
|
||||||
|
|||||||
Reference in New Issue
Block a user