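"""Router that scrapes Google results restricted to the configured fact-checking
sources, ranks the hits by embedding similarity to the search query, and forwards
the matching URLs to the AI fact-check service."""
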
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Optional
import requests
from bs4 import BeautifulSoup
import urllib.parse
import numpy as np
from time import sleep
import logging

from app.services.openai_client import OpenAIClient
from app.config import OPENAI_API_KEY
from app.websites.fact_checker_website import SOURCES, get_all_sources
from app.api.ai_fact_check import ai_fact_check
from app.models.fact_check_models import AIFactCheckRequest, AIFactCheckResponse

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

scrap_websites_router = APIRouter()

# Configuration for rate limiting
RATE_LIMIT_DELAY = 2  # Delay between requests in seconds
MAX_RETRIES = 1       # Maximum number of attempts per domain (a value of 1 means no retries)
RETRY_DELAY = 1       # Delay between retries in seconds

class SearchRequest(BaseModel):
    search_text: str
    source_types: List[str] = ["fact_checkers"]

class UrlSimilarityInfo(BaseModel):
    url: str
    similarity: float
    extracted_text: str

class SearchResponse(BaseModel):
    results: Dict[str, List[str]]
    error_messages: Dict[str, str]
    ai_fact_check_result: Optional[AIFactCheckResponse] = None

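# Illustrative request/response shapes for POST /search (all values below are made-up
# examples, not taken from the configured sources; any router prefix is omitted):
#
#   request:  {"search_text": "some claim to verify",
#              "source_types": ["fact_checkers"]}
#   response: {"results": {"example-factchecker.org": ["https://example-factchecker.org/..."]},
#              "error_messages": {},
#              "ai_fact_check_result": { ... AIFactCheckResponse fields ... }}
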
def extract_url_text(url: str) -> str:
    """Extract and process meaningful text from URL path with improved cleaning"""
    logger.debug(f"Extracting text from URL: {url}")
    try:
        parsed = urllib.parse.urlparse(url)
        path = parsed.path
        path = path.replace('.html', '').replace('/index', '').replace('.php', '')
        segments = [seg for seg in path.split('/') if seg]
        cleaned_segments = []
        for segment in segments:
            segment = segment.replace('-', ' ').replace('_', ' ')
            if not (segment.replace(' ', '').isdigit() or
                    all(part.isdigit() for part in segment.split() if part)):
                cleaned_segments.append(segment)

        common_words = {
            'www', 'live', 'news', 'intl', 'index', 'world', 'us', 'uk',
            'updates', 'update', 'latest', 'breaking', 'new', 'article'
        }

        text = ' '.join(cleaned_segments)
        words = [word.lower() for word in text.split()
                 if word.lower() not in common_words and len(word) > 1]

        result = ' '.join(words)
        logger.debug(f"Extracted text: {result}")
        return result
    except Exception as e:
        logger.error(f"Error extracting text from URL {url}: {str(e)}")
        return ''

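# Illustrative example (hypothetical URL, not from the configured sources):
#   extract_url_text("https://example-factchecker.org/factchecks/2024/jan/some-claim-title.html")
# drops the numeric and common segments and returns roughly "factchecks jan some claim title".
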
def google_search_scraper(search_text: str, site_domain: str, retry_count: int = 0) -> List[str]:
    """Scrape Google search results with retry logic and rate limiting"""
    logger.info(f"Searching for '{search_text}' on domain: {site_domain} (Attempt {retry_count + 1}/{MAX_RETRIES})")

    if retry_count >= MAX_RETRIES:
        logger.error(f"Max retries exceeded for domain: {site_domain}")
        raise HTTPException(
            status_code=429,
            detail=f"Max retries exceeded for {site_domain}"
        )

    # Scope the query to the target domain with Google's site: operator; the operator
    # must not be wrapped in quotes, or Google treats it as a literal phrase.
    query = f"{search_text} site:{site_domain}"
    encoded_query = urllib.parse.quote(query)
    base_url = "https://www.google.com/search"
    url = f"{base_url}?q={encoded_query}"
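    # For example (hypothetical values), search_text="example claim" and
    # site_domain="example-factchecker.org" produce roughly:
    #   https://www.google.com/search?q=example%20claim%20site%3Aexample-factchecker.org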

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    try:
        logger.debug(f"Waiting {RATE_LIMIT_DELAY} seconds before request")
        sleep(RATE_LIMIT_DELAY)

        logger.debug(f"Making request to Google Search for domain: {site_domain}")
        # A timeout keeps a stalled request from blocking the whole search
        response = requests.get(url, headers=headers, timeout=15)

        if response.status_code == 429 or "sorry/index" in response.url:
            logger.warning(f"Rate limit hit for domain {site_domain}. Retrying after delay...")
            sleep(RETRY_DELAY)
            return google_search_scraper(search_text, site_domain, retry_count + 1)

        response.raise_for_status()

        soup = BeautifulSoup(response.content, 'html.parser')
        search_results = soup.find_all('div', class_='g')

        # Only the top few organic results are considered for each domain
        urls = []
        for result in search_results[:3]:
            link = result.find('a')
            if link and 'href' in link.attrs:
                url = link['href']
                if url.startswith('http'):
                    urls.append(url)

        logger.info(f"Found {len(urls)} results for domain: {site_domain}")
        return urls[:5]

    except requests.RequestException as e:
        if retry_count < MAX_RETRIES:
            logger.warning(f"Request failed for {site_domain}. Retrying... Error: {str(e)}")
            sleep(RETRY_DELAY)
            return google_search_scraper(search_text, site_domain, retry_count + 1)
        logger.error(f"All retries failed for domain {site_domain}. Error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error scraping {site_domain}: {str(e)}"
        )

def calculate_similarity(query_embedding: List[float], url_embedding: List[float]) -> float:
    """Calculate cosine similarity between two embeddings"""
    query_array = np.array(query_embedding)
    url_array = np.array(url_embedding)

    similarity = np.dot(url_array, query_array) / (
        np.linalg.norm(url_array) * np.linalg.norm(query_array)
    )
    return float(similarity)

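# Quick sanity examples (not used by the route): identical non-zero vectors score 1.0
# and orthogonal vectors score 0.0, e.g.
#   calculate_similarity([1.0, 0.0], [1.0, 0.0])  # -> 1.0
#   calculate_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0
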
@scrap_websites_router.post("/search", response_model=SearchResponse)
async def search_websites(request: SearchRequest):
    """Search the configured fact-checking domains for the query, keep URLs whose
    path text is semantically similar to the query, and run an AI fact check on them."""
    logger.info(f"Starting search with query: {request.search_text}")
    logger.info(f"Source types requested: {request.source_types}")

    results = {}
    error_messages = {}
    url_similarities = {}  # Collected per domain for ranking; not included in the response

    # Initialize OpenAI client
    logger.debug("Initializing OpenAI client")
    openai_client = OpenAIClient(OPENAI_API_KEY)

    # Get domains based on requested source types
    domains = []
    for source_type in request.source_types:
        if source_type in SOURCES:
            domains.extend([source.domain for source in SOURCES[source_type]])

    if not domains:
        logger.warning("No valid source types provided. Using all available domains.")
        domains = [source.domain for source in get_all_sources()]

    logger.info(f"Processing {len(domains)} domains")

    # Embed the raw search text once; it is compared against each URL's extracted text
    search_context = request.search_text
    logger.debug("Getting query embedding from OpenAI")
    query_embedding = openai_client.get_embeddings([search_context])[0]

    # Similarity threshold for keeping a URL as relevant
    SIMILARITY_THRESHOLD = 0.75

    for domain in domains:
        logger.info(f"Processing domain: {domain}")
        try:
            urls = google_search_scraper(request.search_text, domain)
            url_sims = []
            valid_urls = []

            logger.debug(f"Found {len(urls)} URLs for domain {domain}")

            for url in urls:
                url_text = extract_url_text(url)

                if not url_text:
                    logger.debug(f"No meaningful text extracted from URL: {url}")
                    continue

                logger.debug("Getting URL embedding from OpenAI")
                url_embedding = openai_client.get_embeddings([url_text])[0]
                similarity = calculate_similarity(query_embedding, url_embedding)

                logger.debug(f"Similarity score for {url}: {similarity}")

                url_sims.append(UrlSimilarityInfo(
                    url=url,
                    similarity=similarity,
                    extracted_text=url_text
                ))

                if similarity >= SIMILARITY_THRESHOLD:
                    valid_urls.append(url)

            results[domain] = valid_urls
            url_similarities[domain] = sorted(url_sims,
                                              key=lambda x: x.similarity,
                                              reverse=True)

            logger.info(f"Successfully processed domain {domain}. Found {len(valid_urls)} valid URLs")

        except HTTPException as e:
            logger.error(f"HTTP Exception for domain {domain}: {str(e.detail)}")
            error_messages[domain] = str(e.detail)
        except Exception as e:
            logger.error(f"Unexpected error for domain {domain}: {str(e)}")
            error_messages[domain] = f"Unexpected error for {domain}: {str(e)}"

    logger.info("Search completed")
    logger.debug(f"Results found for {len(results)} domains")
    logger.debug(f"Errors encountered for {len(error_messages)} domains")

    # Collect all valid URLs from results
    all_valid_urls = []
    for domain_urls in results.values():
        all_valid_urls.extend(domain_urls)

    logger.info(f"Total valid URLs collected: {len(all_valid_urls)}")

    # Create request body for AI fact check
    if all_valid_urls:
        fact_check_request = AIFactCheckRequest(
            content=request.search_text,
            urls=all_valid_urls
        )

        logger.info("Calling AI fact check service")
        try:
            ai_response = await ai_fact_check(fact_check_request)
            logger.info("AI fact check completed successfully")

            # Return response with AI fact check results
            return SearchResponse(
                results=results,
                error_messages=error_messages,
                ai_fact_check_result=ai_response
            )

        except Exception as e:
            logger.error(f"Error during AI fact check: {str(e)}")
            error_messages["ai_fact_check"] = f"Error during fact checking: {str(e)}"

    # Return response without AI fact check if no valid URLs were found or an error occurred
    return SearchResponse(
        results=results,
        error_messages=error_messages,
        ai_fact_check_result=None
    )
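
# Illustrative wiring sketch (the module path below is an assumption; adjust to the
# actual project layout):
#
#   from fastapi import FastAPI
#   from app.api.scrap_websites import scrap_websites_router
#
#   app = FastAPI()
#   app.include_router(scrap_websites_router)
#
# The endpoint is then reachable as POST /search with a SearchRequest JSON body.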