# Mirror of https://github.com/Kartvaya2008/autostream-ai-agent (synced 2026-04-21)
"""
|
|
rag.py — Retrieval-Augmented Generation pipeline for AutoStream Agent
|
|
Loads knowledge_base.json, chunks it, embeds with sentence-transformers,
|
|
indexes with FAISS, and retrieves top-k relevant chunks at query time.
|
|
"""
|
|
|
|
import json
import logging
import os
import re
from pathlib import Path
from typing import Optional

# Module-level logger for this package.
# NOTE(review): `os` appears unused in this module — confirm before removing.
logger = logging.getLogger("autostream.rag")

# ── Optional heavy imports (graceful fallback) ─────────────────────────────────
# Semantic retrieval needs numpy, faiss, and sentence-transformers. If any of
# them is missing, the pipeline degrades to the keyword retriever below so the
# module still imports and works in minimal environments.
try:
    import numpy as np
    import faiss
    from sentence_transformers import SentenceTransformer

    _FAISS_AVAILABLE = True
except ImportError:
    _FAISS_AVAILABLE = False
    logger.warning("FAISS / sentence-transformers not installed. Falling back to keyword RAG.")

# The knowledge base JSON is expected to ship alongside this module.
KB_PATH = Path(__file__).parent / "knowledge_base.json"
|
|
|
|
|
|
# ── Knowledge base loader ──────────────────────────────────────────────────────

def load_knowledge_base(path: Optional[Path] = None) -> dict:
    """Load and parse the JSON knowledge base.

    Args:
        path: Location of the knowledge base file. Defaults to ``KB_PATH``,
            resolved at call time (not at def time) so the module constant
            can be patched in tests.

    Returns:
        The parsed knowledge base as a dict.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if path is None:
        path = KB_PATH
    if not path.exists():
        raise FileNotFoundError(f"knowledge_base.json not found at {path}")
    # Explicit encoding: JSON is UTF-8 by spec; never rely on the platform
    # default encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
def chunk_knowledge_base(data: dict) -> list[dict]:
    """
    Convert the JSON knowledge base into a flat list of text chunks,
    each with a source label.

    Every chunk is a dict with "id", "source", and "text" keys — the unit
    the retrievers index and score. Missing sections ("plans", "policies",
    "faq") contribute no chunks; missing inner fields fall back to empty
    strings instead of raising KeyError, so a partially filled knowledge
    base still chunks cleanly.
    """
    chunks: list[dict] = []

    # Company overview — always emitted, even when the fields are blank.
    chunks.append({
        "id": "overview",
        "source": "Company Overview",
        "text": f"{data.get('company', '')} — {data.get('description', '')}",
    })

    # Plans: one chunk per pricing plan, folding the feature list into prose.
    for plan_key, plan in data.get("plans", {}).items():
        features = ", ".join(plan.get("features", []))
        chunks.append({
            "id": f"plan_{plan_key}",
            "source": f"{plan_key.title()} Plan",
            "text": (
                f"{plan_key.title()} Plan costs {plan.get('price', '')}. "
                f"Features include: {features}."
            ),
        })

    # Policies: one chunk per policy with a human-readable title.
    for policy_key, policy_val in data.get("policies", {}).items():
        title = policy_key.replace("_", " ").title()
        chunks.append({
            "id": f"policy_{policy_key}",
            "source": f"Policy: {title}",
            "text": f"{title}: {policy_val}",
        })

    # FAQ (if present). NOTE: the id is positional within the *flat* chunk
    # list (not the FAQ index) — kept for backward compatibility.
    for faq in data.get("faq", []):
        chunks.append({
            "id": f"faq_{len(chunks)}",
            "source": "FAQ",
            "text": f"Q: {faq.get('question', '')} A: {faq.get('answer', '')}",
        })

    # Lazy %-style args: formatting is skipped when INFO is disabled.
    logger.info("Knowledge base chunked into %d segments.", len(chunks))
    return chunks
|
|
|
|
|
|
# ── FAISS-backed retriever ─────────────────────────────────────────────────────

class FAISSRetriever:
    """Semantic retriever: sentence-transformer embeddings + FAISS L2 index."""

    def __init__(self, chunks: list[dict], model_name: str = "all-MiniLM-L6-v2"):
        """Embed every chunk's text and build a flat (exact) L2 index.

        Args:
            chunks: Chunk dicts with at least a "text" key.
            model_name: sentence-transformers model used for embeddings.
        """
        self.chunks = chunks
        self.model = SentenceTransformer(model_name)
        texts = [c["text"] for c in chunks]
        embeddings = self.model.encode(texts, convert_to_numpy=True)
        # IndexFlatL2 is exact brute-force search — fine for a small KB.
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings.astype("float32"))
        logger.info(f"FAISS index built with {len(chunks)} chunks.")

    def retrieve(self, query: str, top_k: int = 3) -> list[dict]:
        """Return up to top_k chunks nearest to the query.

        Each result is the chunk dict plus a "score" key holding the L2
        distance to the query embedding (lower = closer).
        """
        q_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
        distances, indices = self.index.search(q_vec, top_k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads `indices` with -1 when fewer than top_k vectors
            # exist. The previous `idx < len(...)` check let -1 through and
            # silently returned the *last* chunk with a bogus score, so
            # require a valid non-negative index.
            if 0 <= idx < len(self.chunks):
                results.append({**self.chunks[idx], "score": float(dist)})
        logger.debug(f"RAG retrieved {len(results)} chunks for: '{query}'")
        return results
|
|
|
|
|
|
# ── Keyword fallback retriever ─────────────────────────────────────────────────

class KeywordRetriever:
    """Dependency-free fallback retriever based on word-overlap scoring."""

    def __init__(self, chunks: list[dict]):
        # Nothing to precompute — chunks are scored on the fly per query.
        self.chunks = chunks

    def retrieve(self, query: str, top_k: int = 3) -> list[dict]:
        """Return the top_k chunks sharing the most distinct words with the query."""
        query_words = set(re.findall(r"\w+", query.lower()))

        def overlap(chunk: dict) -> int:
            # Count of distinct query words that appear in the chunk text.
            return len(query_words & set(re.findall(r"\w+", chunk["text"].lower())))

        # sorted() is stable, so tied chunks keep their original order —
        # the same ordering the list-of-tuples sort produced before.
        results = sorted(self.chunks, key=overlap, reverse=True)[:top_k]
        logger.debug(f"Keyword RAG retrieved {len(results)} chunks for: '{query}'")
        return results
|
|
|
|
|
|
# ── Public RAG interface ───────────────────────────────────────────────────────

class RAGPipeline:
    """Facade over the knowledge base: loads, chunks, and picks a retriever."""

    def __init__(self):
        self.raw_data = load_knowledge_base()
        self.chunks = chunk_knowledge_base(self.raw_data)
        self.retriever = self._build_retriever()

    def _build_retriever(self):
        # Prefer semantic search; degrade to keyword matching when the heavy
        # dependencies are missing or fail to initialize.
        if not _FAISS_AVAILABLE:
            logger.info("Using keyword-based retriever (FAISS not available).")
            return KeywordRetriever(self.chunks)
        try:
            retriever = FAISSRetriever(self.chunks)
        except Exception as e:
            logger.warning(f"FAISS init failed ({e}), falling back to keyword.")
            return KeywordRetriever(self.chunks)
        logger.info("Using FAISS semantic retriever.")
        return retriever

    def get_context(self, query: str, top_k: int = 3) -> str:
        """Return a formatted context string for the LLM prompt."""
        retrieved = self.retriever.retrieve(query, top_k=top_k)
        if not retrieved:
            return "No specific information found in the knowledge base."
        lines = []
        for hit in retrieved:
            lines.append(f"[{hit['source']}] {hit['text']}")
        return "\n".join(lines)

    def full_context(self) -> str:
        """Return ALL chunks as a single context string (for grounding)."""
        texts = [chunk["text"] for chunk in self.chunks]
        return "\n".join(texts)
|