autostream-ai-agent/backend/rag.py
2026-04-12 21:03:52 +05:30

156 lines
5.9 KiB
Python

"""
rag.py — Retrieval-Augmented Generation pipeline for AutoStream Agent
Loads knowledge_base.json, chunks it, embeds with sentence-transformers,
indexes with FAISS, and retrieves top-k relevant chunks at query time.
"""
import json
import logging
import os
import re
from pathlib import Path
from typing import Optional
logger = logging.getLogger("autostream.rag")
# ── Optional heavy imports (graceful fallback) ─────────────────────────────────
try:
    import numpy as np
    import faiss
    from sentence_transformers import SentenceTransformer

    # All optional dependencies imported cleanly: semantic retrieval is possible.
    _FAISS_AVAILABLE = True
except ImportError:
    # Any one missing package disables semantic search; RAGPipeline will
    # construct the KeywordRetriever fallback instead.
    _FAISS_AVAILABLE = False
    logger.warning("FAISS / sentence-transformers not installed. Falling back to keyword RAG.")
# Default on-disk location of the knowledge base (sits next to this module).
KB_PATH = Path(__file__).parent / "knowledge_base.json"
# ── Knowledge base loader ──────────────────────────────────────────────────────
def load_knowledge_base(path: Path = KB_PATH) -> dict:
    """Load and parse the JSON knowledge base.

    Args:
        path: File to read; defaults to ``knowledge_base.json`` beside this
            module.

    Returns:
        The parsed JSON document as a dict.

    Raises:
        FileNotFoundError: If *path* does not exist (raised explicitly for a
            clearer message than the bare OS error from ``open``).
        json.JSONDecodeError: If the file contains invalid JSON.
    """
    if not path.exists():
        raise FileNotFoundError(f"knowledge_base.json not found at {path}")
    # Pin UTF-8 explicitly: open()'s default encoding is platform-dependent
    # (e.g. cp1252 on Windows) and would garble non-ASCII KB text.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def chunk_knowledge_base(data: dict) -> list[dict]:
    """Convert the JSON knowledge base into a flat list of text chunks.

    Args:
        data: Parsed knowledge base. Requires "company" and "description";
            "plans", "policies", and "faq" are optional.

    Returns:
        List of ``{"id", "source", "text"}`` dicts ready for indexing.
    """
    chunks: list[dict] = []
    # Company overview. BUG FIX: name and description were concatenated with
    # no separator ("AcmeAcme is ..."); join them with ": ".
    chunks.append({
        "id": "overview",
        "source": "Company Overview",
        "text": f"{data['company']}: {data['description']}",
    })
    # One chunk per subscription plan.
    for plan_key, plan in data.get("plans", {}).items():
        features = ", ".join(plan.get("features", []))
        chunks.append({
            "id": f"plan_{plan_key}",
            "source": f"{plan_key.title()} Plan",
            "text": (
                f"{plan_key.title()} Plan costs {plan['price']}. "
                f"Features include: {features}."
            ),
        })
    # One chunk per policy; snake_case keys become title-cased labels
    # (hoisted: the label was previously computed twice per entry).
    for policy_key, policy_val in data.get("policies", {}).items():
        label = policy_key.replace("_", " ").title()
        chunks.append({
            "id": f"policy_{policy_key}",
            "source": f"Policy: {label}",
            "text": f"{label}: {policy_val}",
        })
    # FAQ entries (optional). The id uses the running chunk count, so it is
    # unique but position-dependent.
    for faq in data.get("faq", []):
        chunks.append({
            "id": f"faq_{len(chunks)}",
            "source": "FAQ",
            "text": f"Q: {faq['question']} A: {faq['answer']}",
        })
    logger.info(f"Knowledge base chunked into {len(chunks)} segments.")
    return chunks
# ── FAISS-backed retriever ─────────────────────────────────────────────────────
class FAISSRetriever:
    """Semantic retriever: sentence-transformer embeddings + FAISS L2 index."""

    def __init__(self, chunks: list[dict], model_name: str = "all-MiniLM-L6-v2"):
        """Embed every chunk's text and build an exact (IndexFlatL2) index.

        Args:
            chunks: List of ``{"id", "source", "text"}`` dicts to index.
            model_name: sentence-transformers model identifier.
        """
        self.chunks = chunks
        self.model = SentenceTransformer(model_name)
        texts = [c["text"] for c in chunks]
        embeddings = self.model.encode(texts, convert_to_numpy=True)
        # IndexFlatL2 performs exhaustive exact search — fine for a small KB.
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings.astype("float32"))
        logger.info(f"FAISS index built with {len(chunks)} chunks.")

    def retrieve(self, query: str, top_k: int = 3) -> list[dict]:
        """Return up to *top_k* chunks nearest to *query*.

        Each hit is the chunk dict plus a "score" key holding the L2
        distance (lower is closer).
        """
        q_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
        distances, indices = self.index.search(q_vec, top_k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # BUG FIX: FAISS pads missing neighbors with index -1 when top_k
            # exceeds the number of indexed vectors; the old `idx < len(...)`
            # check let -1 through and spuriously returned chunks[-1].
            if 0 <= idx < len(self.chunks):
                results.append({**self.chunks[idx], "score": float(dist)})
        logger.debug(f"RAG retrieved {len(results)} chunks for: '{query}'")
        return results
# ── Keyword fallback retriever ─────────────────────────────────────────────────
class KeywordRetriever:
    """Fallback retriever: ranks chunks by word overlap with the query."""

    def __init__(self, chunks: list[dict]):
        """Store the chunk list; no index is needed for keyword matching."""
        self.chunks = chunks

    def retrieve(self, query: str, top_k: int = 3) -> list[dict]:
        """Return up to *top_k* chunks sharing the most words with *query*.

        Fixes vs. the previous version:
        - Chunks with zero overlap are no longer used to pad the result,
          so an unanswerable query yields [] (letting the caller report
          "no information found") instead of arbitrary unrelated chunks.
        - Each hit carries a "score" key (overlap count as float), matching
          the result shape produced by FAISSRetriever.retrieve.
        """
        tokens = set(re.findall(r"\w+", query.lower()))
        scored = []
        for chunk in self.chunks:
            chunk_tokens = set(re.findall(r"\w+", chunk["text"].lower()))
            overlap = len(tokens & chunk_tokens)
            if overlap > 0:
                scored.append((overlap, chunk))
        # Stable sort: ties keep knowledge-base order, as before.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        results = [{**chunk, "score": float(score)} for score, chunk in scored[:top_k]]
        logger.debug(f"Keyword RAG retrieved {len(results)} chunks for: '{query}'")
        return results
# ── Public RAG interface ───────────────────────────────────────────────────────
class RAGPipeline:
    """Top-level RAG facade: loads the KB, chunks it, and picks a retriever."""

    def __init__(self):
        """Load and chunk the knowledge base, then select a retriever.

        Prefers the FAISS semantic retriever; degrades to keyword matching
        when the optional dependencies are absent or fail to initialize.
        """
        self.raw_data = load_knowledge_base()
        self.chunks = chunk_knowledge_base(self.raw_data)
        # Guard clause: without FAISS there is only one choice.
        if not _FAISS_AVAILABLE:
            self.retriever = KeywordRetriever(self.chunks)
            logger.info("Using keyword-based retriever (FAISS not available).")
            return
        try:
            self.retriever = FAISSRetriever(self.chunks)
        except Exception as e:
            logger.warning(f"FAISS init failed ({e}), falling back to keyword.")
            self.retriever = KeywordRetriever(self.chunks)
        else:
            logger.info("Using FAISS semantic retriever.")

    def get_context(self, query: str, top_k: int = 3) -> str:
        """Return a formatted context string for the LLM prompt."""
        matches = self.retriever.retrieve(query, top_k=top_k)
        if not matches:
            return "No specific information found in the knowledge base."
        lines = []
        for match in matches:
            lines.append(f"[{match['source']}] {match['text']}")
        return "\n".join(lines)

    def full_context(self) -> str:
        """Return ALL chunks as a single context string (for grounding)."""
        texts = [entry["text"] for entry in self.chunks]
        return "\n".join(texts)