fact-checker-backend/app/services/openai_client.py
2024-12-12 17:31:44 +06:00

173 lines
No EOL
7.1 KiB
Python

import asyncio
import json
import logging as logger
from typing import Any, Dict, List, Optional

import numpy as np
import openai
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.document_transformers import BeautifulSoupTransformer
from langchain_core.documents import Document
class OpenAIClient:
    """Thin wrapper around the module-level (pre-1.0) ``openai`` API used by the fact checker."""

    def __init__(
        self,
        api_key: str,
        chat_model: str = "gpt-4",
        embedding_model: str = "text-embedding-ada-002",
    ):
        """
        Initialize OpenAI client with the provided API key.

        Args:
            api_key: OpenAI API key. It is stored on the instance and passed
                explicitly on every request, so several clients with different
                keys do not overwrite each other. It is also still assigned to
                the global ``openai.api_key`` for any code that relies on it.
            chat_model: Model used by ``generate_text_response``.
            embedding_model: Model used by ``get_embeddings``.
        """
        self.api_key = api_key
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        openai.api_key = api_key

    @staticmethod
    def _parse_json_content(content: Optional[str]) -> Any:
        """Parse a chat-completion message body as JSON.

        Chat models often wrap JSON in a Markdown code fence (```json ... ```)
        even when asked for raw JSON. A fence surrounding the whole reply is
        stripped before parsing. ``None`` is treated as an empty reply.

        Raises:
            json.JSONDecodeError: If the (unfenced) text is not valid JSON.
        """
        text = (content or "").strip()
        if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
            inner = text[3:-3]
            # Drop an optional language tag (e.g. "json") on the opening fence line.
            head, sep, rest = inner.partition("\n")
            if sep and (not head.strip() or head.strip().isidentifier()):
                inner = rest
            text = inner.strip()
        return json.loads(text)

    async def generate_text_response(self, system_prompt: str, user_prompt: str, max_tokens: int) -> dict:
        """
        Generate a JSON response using OpenAI's chat completion API.

        The SDK call is blocking, so it runs in the event loop's default
        executor rather than stalling the loop while waiting on the network.

        Returns:
            A dict with "response" (the model reply parsed from JSON) and the
            "prompt_tokens", "completion_tokens" and "total_tokens" usage counts.

        Raises:
            Exception: "Failed to parse OpenAI response as JSON: ..." if the reply
                is not valid JSON, or "OpenAI text generation error: ..." for any
                other failure. The underlying error is chained as ``__cause__``.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        try:
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: openai.ChatCompletion.create(
                    model=self.chat_model,
                    messages=messages,
                    max_tokens=max_tokens,
                    api_key=self.api_key,
                ),
            )
            content = response['choices'][0]['message']['content']
            parsed_content = self._parse_json_content(content)
            usage = response['usage']
            return {
                "response": parsed_content,  # parsed JSON (a dict for the fact-check prompt)
                "prompt_tokens": usage['prompt_tokens'],
                "completion_tokens": usage['completion_tokens'],
                "total_tokens": usage['total_tokens'],
            }
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse OpenAI response as JSON: {str(e)}") from e
        except Exception as e:
            raise Exception(f"OpenAI text generation error: {str(e)}") from e

    def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Retrieve embeddings for a list of texts using OpenAI's embedding API.

        This call is synchronous. Async callers should run it in an executor.

        Returns:
            One embedding per input, in the same order as ``texts``. An empty
            input returns ``[]`` without making a request.

        Raises:
            Exception: "OpenAI embedding error: ..." with the underlying error
                chained as ``__cause__``.
        """
        if not texts:
            return []
        try:
            response = openai.Embedding.create(
                input=texts,
                model=self.embedding_model,
                api_key=self.api_key,
            )
            # Each item carries the "index" of the input it embeds; sort on it so
            # the output lines up with `texts` regardless of response order.
            items = sorted(response['data'], key=lambda item: item['index'])
            return [item['embedding'] for item in items]
        except Exception as e:
            raise Exception(f"OpenAI embedding error: {str(e)}") from e
class AIFactChecker:
    """Check a statement against a web page: scrape it, retrieve similar chunks, ask the LLM."""

    def __init__(self, openai_client: OpenAIClient):
        """Initialize the fact checker with OpenAI client.

        Args:
            openai_client: Client used both for embeddings and for the final verdict.
        """
        self.openai_client = openai_client
        # ~1000-character chunks with 200 characters of overlap. Separators are
        # listed from coarsest (paragraph) to finest (single character).
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        )

    async def scrape_webpage(self, url: str) -> List[Document]:
        """Scrape webpage content using LangChain's AsyncHtmlLoader.

        The page is fetched, transformed by BeautifulSoupTransformer, and split
        into chunks with ``self.text_splitter``.

        Returns:
            The chunks as Documents (may be empty).
        """
        try:
            loader = AsyncHtmlLoader([url])
            docs = await loader.aload()
            docs_transformed = BeautifulSoupTransformer().transform_documents(docs)
            docs_chunks = self.text_splitter.split_documents(docs_transformed)
            logger.info("Successfully scraped webpage | chunks=%d", len(docs_chunks))
            return docs_chunks
        except Exception as e:
            logger.error("Error scraping webpage | url=%s | error=%s", url, e)
            raise

    def find_relevant_chunks(
        self,
        query_embedding: List[float],
        doc_embeddings: List[List[float]],
        docs: List[Document],
        top_k: int = 5,
    ) -> List[Document]:
        """Find most relevant document chunks using cosine similarity.

        Args:
            query_embedding: Embedding of the statement being checked.
            doc_embeddings: One embedding per entry of ``docs``, in the same order.
            docs: Candidate chunks.
            top_k: Maximum number of chunks to return (default 5).

        Returns:
            Up to ``top_k`` documents, most similar first. Empty when ``docs`` is
            empty or ``top_k`` <= 0.

        Raises:
            ValueError: If ``doc_embeddings`` and ``docs`` differ in length.
        """
        try:
            if len(doc_embeddings) != len(docs):
                raise ValueError(
                    f"doc_embeddings ({len(doc_embeddings)}) and docs ({len(docs)}) must have the same length"
                )
            if not docs or top_k <= 0:
                return []
            query_array = np.asarray(query_embedding, dtype=float)
            chunks_array = np.asarray(doc_embeddings, dtype=float)
            dots = chunks_array @ query_array
            norms = np.linalg.norm(chunks_array, axis=1) * np.linalg.norm(query_array)
            # A zero vector has no direction: score it 0 instead of dividing by
            # zero, which would produce NaN scores that sort ahead of real ones.
            similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)
            # Descending order; the stable sort keeps earlier chunks first on ties.
            top_indices = np.argsort(-similarities, kind="stable")[:top_k]
            return [docs[i] for i in top_indices]
        except Exception as e:
            logger.error("Error finding relevant chunks | error=%s", e)
            raise

    async def verify_fact(
        self, query: str, relevant_docs: List[Document], max_tokens: int = 800
    ) -> Dict[str, Any]:
        """Verify fact using OpenAI's API with context from relevant documents.

        Args:
            query: The statement to verify.
            relevant_docs: Chunks whose text is joined into the prompt context.
            max_tokens: Completion token budget for the verdict (default 800).

        Returns:
            A dict with "verification_result" (the parsed JSON verdict),
            "sources" (unique ``metadata['source']`` values in first-seen order),
            "context_used" (chunk texts) and "token_usage".
        """
        try:
            # NOTE(review): the context is scraped, untrusted web text placed
            # directly in the prompt; it can contain prompt-injection attempts.
            context = "\n\n".join(doc.page_content for doc in relevant_docs)
            system_prompt = """You are a professional fact-checking assistant. Analyze the provided context
and determine if the given statement is true, false, or if there isn't enough information.
Provide your response in the following JSON format:
{
"verdict": "True/False/Insufficient Information",
"confidence": "High/Medium/Low",
"evidence": "Direct quotes or evidence from the context",
"reasoning": "Your detailed analysis and reasoning",
"missing_info": "Any important missing information (if applicable)"
}"""
            user_prompt = f"""Context:
{context}
Statement to verify: "{query}"
Analyze the statement based on the provided context and return your response in the specified JSON format."""
            response = await self.openai_client.generate_text_response(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                max_tokens=max_tokens,
            )
            # dict.fromkeys de-duplicates while keeping a deterministic order.
            sources = list(dict.fromkeys(
                doc.metadata.get('source', 'Unknown source') for doc in relevant_docs
            ))
            return {
                "verification_result": response["response"],  # parsed JSON verdict
                "sources": sources,
                "context_used": [doc.page_content for doc in relevant_docs],
                "token_usage": {
                    "prompt_tokens": response["prompt_tokens"],
                    "completion_tokens": response["completion_tokens"],
                    "total_tokens": response["total_tokens"],
                },
            }
        except Exception as e:
            logger.error("Error verifying fact | error=%s", e)
            raise

    async def check_fact(self, url: str, query: str) -> Dict[str, Any]:
        """Main method to check a fact against a webpage.

        The method scrapes ``url`` and embeds the chunks together with
        ``query``. It then picks the most similar chunks and asks the model
        for a verdict. A page that yields no chunks reaches ``verify_fact``
        with empty context rather than failing on an empty embedding request.
        """
        try:
            docs = await self.scrape_webpage(url)
            doc_texts = [doc.page_content for doc in docs]
            # One embedding request for all chunks plus the query (appended last).
            # get_embeddings is synchronous, so run it off the event loop.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None, self.openai_client.get_embeddings, doc_texts + [query]
            )
            query_embedding, doc_embeddings = embeddings[-1], embeddings[:-1]
            relevant_docs = self.find_relevant_chunks(query_embedding, doc_embeddings, docs)
            return await self.verify_fact(query, relevant_docs)
        except Exception as e:
            logger.error("Error checking fact | url=%s | error=%s", url, e)
            raise