# Source listing metadata: 201 lines, 7.5 KiB, Python.
import asyncio
import json
import logging as logger
from typing import Any, Dict, List

import aiohttp
import numpy as np
import openai
from bs4 import BeautifulSoup
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers import BeautifulSoupTransformer
from langchain_core.documents import Document
|
|
|
|
class OpenAIClient:
    """Wrapper around the legacy (pre-1.0) module-level ``openai`` API."""

    def __init__(self, api_key: str):
        """
        Initialize OpenAI client with the provided API key.

        Args:
            api_key: OpenAI API key. The legacy API stores it globally on the
                ``openai`` module, so it is shared by every instance in the
                process (the last instance constructed wins).
        """
        openai.api_key = api_key

    @staticmethod
    def _parse_json_content(content: str) -> Any:
        """Parse a model reply as JSON, tolerating a surrounding Markdown code fence.

        Chat models frequently answer "return JSON" prompts with a fenced block
        such as ```` ```json ... ``` ````; the fence (and an optional language tag
        on the opening line) is stripped before parsing.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        text = content.strip()
        if text.startswith("```"):
            text = text[3:]
            first_line, newline, rest = text.partition("\n")
            tag = first_line.strip()
            # Only drop the first line when it is empty or a bare language tag
            # ("json"); otherwise it is already part of the JSON payload.
            if newline and (not tag or tag.isalnum()):
                text = rest
            text = text.strip()
            if text.endswith("```"):
                text = text[:-3]
        return json.loads(text)

    async def generate_text_response(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int,
        model: str = "gpt-4",
    ) -> dict:
        """
        Generate a response using OpenAI's chat completion API and parse it as JSON.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            max_tokens: Completion token limit.
            model: Chat model name (defaults to the previously hard-coded "gpt-4").

        Returns:
            ``{"response": <parsed JSON>, "prompt_tokens": int,
            "completion_tokens": int, "total_tokens": int}``.

        Raises:
            Exception: Wrapping either a JSON parse failure or any API error;
                the original error is chained as ``__cause__``.
        """
        try:
            # ``acreate`` is the awaitable variant of ``create`` in the legacy
            # API; the blocking ``create`` would stall the event loop here.
            response = await openai.ChatCompletion.acreate(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=max_tokens,
            )

            content = response["choices"][0]["message"]["content"]
            parsed_content = self._parse_json_content(content)

            usage = response["usage"]
            return {
                "response": parsed_content,  # a parsed JSON value, not the raw string
                "prompt_tokens": usage["prompt_tokens"],
                "completion_tokens": usage["completion_tokens"],
                "total_tokens": usage["total_tokens"],
            }

        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse OpenAI response as JSON: {str(e)}") from e

        except Exception as e:
            raise Exception(f"OpenAI text generation error: {str(e)}") from e

    def get_embeddings(
        self, texts: List[str], model: str = "text-embedding-ada-002"
    ) -> List[List[float]]:
        """
        Retrieve embeddings for a list of texts using OpenAI's embedding API.

        Args:
            texts: Texts to embed; an empty list returns ``[]`` without an API call.
            model: Embedding model name (defaults to the previously hard-coded one).

        Returns:
            One embedding vector per input text, in the order returned by the API.

        Raises:
            Exception: Wrapping any API error, with the original chained as ``__cause__``.
        """
        if not texts:
            return []

        try:
            response = openai.Embedding.create(input=texts, model=model)
            return [item["embedding"] for item in response["data"]]

        except Exception as e:
            raise Exception(f"OpenAI embedding error: {str(e)}") from e
|
|
|
|
class AIFactChecker:
    """Fact-check a statement against the text of a single web page.

    ``check_fact`` runs the pipeline:
    1. scrape and chunk the page;
    2. embed the chunks and the statement;
    3. keep the most similar chunks;
    4. ask the chat model for a JSON verdict grounded in those chunks.
    """

    # System message sent by ``verify_fact``; it asks the model for a JSON
    # object with the keys shown.
    _SYSTEM_PROMPT = (
        "You are a professional fact-checking assistant. Analyze the provided context\n"
        "and determine if the given statement is true, false, or if there isn't enough information.\n"
        "\n"
        "Provide your response in the following JSON format:\n"
        "{\n"
        '    "verdict": "True/False/Insufficient Information",\n'
        '    "confidence": "High/Medium/Low",\n'
        '    "evidence": "Direct quotes or evidence from the context",\n'
        '    "reasoning": "Your detailed analysis and reasoning",\n'
        '    "missing_info": "Any important missing information (if applicable)"\n'
        "}"
    )

    def __init__(self, openai_client: OpenAIClient):
        """Initialize the fact checker with OpenAI client.

        Args:
            openai_client: Client used for embeddings and chat completions.
        """
        self.openai_client = openai_client
        # Chunks of at most 1000 characters (measured with len), overlapping by
        # 200; separators are listed from coarsest (paragraph) to finest.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        )

    async def scrape_webpage(self, url: str, timeout: float = 30.0) -> List[Document]:
        """Fetch ``url`` and return its text split into chunks (no HTML files are saved).

        ``<script>``, ``<style>`` and ``<noscript>`` elements are removed before
        text extraction, because ``get_text`` would otherwise include their raw
        contents (JavaScript/CSS) in the chunks.

        Args:
            url: Page to fetch.
            timeout: Total request timeout in seconds.

        Returns:
            Chunks of the page text, each with ``metadata["source"] == url``.
            May be empty when the page has no extractable text.

        Raises:
            Exception: On a non-200 HTTP status. Network and timeout errors from
                aiohttp propagate unchanged. All errors are logged first.
        """
        try:
            client_timeout = aiohttp.ClientTimeout(total=timeout)
            async with aiohttp.ClientSession(timeout=client_timeout) as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        raise Exception(
                            f"Failed to fetch URL: {url}, status: {response.status}"
                        )
                    # Read raw bytes and let BeautifulSoup decode them, using the
                    # header charset as a hint. ``response.text()`` raises instead
                    # when that charset is missing or wrong for the body.
                    charset = response.charset
                    html_bytes = await response.read()

            soup = BeautifulSoup(html_bytes, "html.parser", from_encoding=charset)
            for tag in soup.find_all(["script", "style", "noscript"]):
                tag.decompose()

            doc = Document(
                page_content=soup.get_text(separator="\n", strip=True),
                metadata={"source": url},
            )
            docs_chunks = self.text_splitter.split_documents([doc])

            logger.info(
                "Successfully scraped webpage | url=%s | chunks=%d", url, len(docs_chunks)
            )
            return docs_chunks

        except Exception as e:
            logger.error("Error scraping webpage | url=%s | error=%s", url, e)
            raise

    def find_relevant_chunks(
        self,
        query_embedding: List[float],
        doc_embeddings: List[List[float]],
        docs: List[Document],
        top_k: int = 5,
    ) -> List[Document]:
        """Return the ``top_k`` documents most cosine-similar to the query.

        Args:
            query_embedding: Embedding of the query.
            doc_embeddings: One embedding per entry of ``docs``, in the same order.
            docs: Candidate documents.
            top_k: Maximum number of documents to return (default 5).

        Returns:
            Up to ``top_k`` documents, most similar first. Ties keep the original
            order. Empty when there are no documents or ``top_k <= 0``.

        Raises:
            ValueError: If ``doc_embeddings`` and ``docs`` differ in length.
        """
        try:
            if len(docs) == 0 or len(doc_embeddings) == 0 or top_k <= 0:
                return []
            if len(doc_embeddings) != len(docs):
                raise ValueError(
                    f"Got {len(doc_embeddings)} embeddings for {len(docs)} documents"
                )

            query_array = np.asarray(query_embedding, dtype=float)
            chunks_array = np.asarray(doc_embeddings, dtype=float)

            norms = np.linalg.norm(chunks_array, axis=1) * np.linalg.norm(query_array)
            # A zero-length vector would give 0/0 = NaN, and argsort places NaN
            # last, so it would be picked as the best match. A zero vector's dot
            # product is 0, so dividing by 1 instead scores it 0.
            safe_norms = np.where(norms == 0, 1.0, norms)
            similarities = (chunks_array @ query_array) / safe_norms

            top_indices = np.argsort(-similarities, kind="stable")[:top_k]
            return [docs[i] for i in top_indices]

        except Exception as e:
            logger.error("Error finding relevant chunks | error=%s", e)
            raise

    async def verify_fact(
        self, query: str, relevant_docs: List[Document]
    ) -> Dict[str, Any]:
        """Verify fact using OpenAI's API with context from relevant documents.

        Args:
            query: Statement to verify.
            relevant_docs: Documents whose text is sent to the model as context.

        Returns:
            ``{"verification_result": <parsed JSON from the model>,
            "sources": [unique source URLs, first-seen order],
            "token_usage": {"prompt_tokens", "completion_tokens", "total_tokens"}}``.
        """
        try:
            context = "\n\n".join(doc.page_content for doc in relevant_docs)

            user_prompt = (
                f"Context:\n{context}\n\n"
                f'Statement to verify: "{query}"\n\n'
                "Analyze the statement based on the provided context and return "
                "your response in the specified JSON format."
            )

            response = await self.openai_client.generate_text_response(
                system_prompt=self._SYSTEM_PROMPT,
                user_prompt=user_prompt,
                max_tokens=800,
            )

            # Deduplicate while keeping first-seen order. Iterating a set of str
            # gives an order that varies between interpreter runs.
            sources = list(
                dict.fromkeys(
                    doc.metadata.get("source", "Unknown source") for doc in relevant_docs
                )
            )

            return {
                "verification_result": response["response"],  # parsed JSON value
                "sources": sources,
                "token_usage": {
                    "prompt_tokens": response["prompt_tokens"],
                    "completion_tokens": response["completion_tokens"],
                    "total_tokens": response["total_tokens"],
                },
            }

        except Exception as e:
            logger.error("Error verifying fact | error=%s", e)
            raise

    async def check_fact(self, url: str, query: str) -> Dict[str, Any]:
        """Main method to check a fact against a webpage.

        Args:
            url: Page whose content is used as evidence.
            query: Statement to verify.

        Returns:
            The dict produced by ``verify_fact``.

        Raises:
            ValueError: If no text could be extracted from the page.
            Exception: Any error from scraping, embedding or verification,
                after it has been logged.
        """
        try:
            docs = await self.scrape_webpage(url)
            if not docs:
                raise ValueError(f"No text content could be extracted from {url}")

            texts = [doc.page_content for doc in docs] + [query]
            # A single embeddings request covers every chunk plus the query.
            # get_embeddings is synchronous, so run it in the default executor
            # instead of blocking the event loop.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None, self.openai_client.get_embeddings, texts
            )
            doc_embeddings, query_embedding = embeddings[:-1], embeddings[-1]

            relevant_docs = self.find_relevant_chunks(
                query_embedding, doc_embeddings, docs
            )
            return await self.verify_fact(query, relevant_docs)

        except Exception as e:
            logger.error("Error checking fact | error=%s", e)
            raise