from fastapi import APIRouter, HTTPException
import httpx
import logging
from urllib.parse import urlparse
from typing import List, Dict, Optional
from pydantic import BaseModel
from app.models.ai_fact_check_models import (
    AIFactCheckRequest,
    FactCheckSource,
    SourceType
)
from app.websites.fact_checker_website import SOURCES, get_all_sources
from app.api.ai_fact_check import ai_fact_check
from app.config import GOOGLE_API_KEY, GOOGLE_ENGINE_ID, GOOGLE_SEARCH_URL


class SearchRequest(BaseModel):
    search_text: str
    source_types: List[str] = ["fact_checkers"]
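
# Example request body for the /search endpoint (illustrative values only):
#   {"search_text": "some claim to verify", "source_types": ["fact_checkers"]}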

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

scrap_websites_router = APIRouter()

# Constants
RESULTS_PER_PAGE = 10
MAX_PAGES = 5
MAX_URLS_PER_DOMAIN = 5
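
# Note: MAX_PAGES * RESULTS_PER_PAGE = 50, which matches the hard cap of 50 URLs
# collected by the /search handler below.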


def get_domain_from_url(url: str) -> str:
    """Extract domain from URL with improved handling."""
    try:
        parsed = urlparse(url)
        domain = parsed.netloc.lower()
        if domain.startswith('www.'):
            domain = domain[4:]
        return domain
    except Exception as e:
        logger.error(f"Error extracting domain from URL {url}: {str(e)}")
        return ""


def is_valid_source_domain(domain: str, sources: List[FactCheckSource]) -> bool:
    """Check if domain matches any source with improved matching logic."""
    if not domain:
        return False

    domain = domain.lower()
    if domain.startswith('www.'):
        domain = domain[4:]

    for source in sources:
        source_domain = source.domain.lower()
        if source_domain.startswith('www.'):
            source_domain = source_domain[4:]

        if domain == source_domain or domain.endswith('.' + source_domain):
            return True

    return False
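
# Exact matches and subdomains both pass: for a source domain of "politifact.com"
# (illustrative), "politifact.com" and "news.politifact.com" are accepted, while
# "notpolitifact.com" is rejected.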


async def build_enhanced_search_query(query: str, sources: List[FactCheckSource]) -> str:
    """Build search query with site restrictions."""
    site_queries = [f"site:{source.domain}" for source in sources]
    site_restriction = " OR ".join(site_queries)
    return f"({query}) ({site_restriction})"


async def google_custom_search(query: str, sources: List[FactCheckSource], page: int = 1) -> Optional[Dict]:
    """Perform Google Custom Search with enhanced query."""
    enhanced_query = await build_enhanced_search_query(query, sources)
    start_index = ((page - 1) * RESULTS_PER_PAGE) + 1

    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_ENGINE_ID,
        "q": enhanced_query,
        "num": RESULTS_PER_PAGE,
        "start": start_index
    }

    async with httpx.AsyncClient(timeout=30.0) as client:
        try:
            response = await client.get(GOOGLE_SEARCH_URL, params=params)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Search error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")


@scrap_websites_router.post("/search")
async def search_websites(request: SearchRequest):
    # Get the source types from the request
    source_types = request.source_types if request.source_types else ["fact_checkers"]

    # Get sources based on requested types
    selected_sources = []
    for source_type in source_types:
        if source_type in SOURCES:
            selected_sources.extend(SOURCES[source_type])

    # If no valid sources found, use fact checkers as default
    if not selected_sources:
        selected_sources = SOURCES["fact_checkers"]

    all_urls = []
    domain_results = {}

    try:
        # Collect candidate URLs page by page, stopping once 50 have been gathered
        for page in range(1, MAX_PAGES + 1):
            if len(all_urls) >= 50:
                break

            search_response = await google_custom_search(request.search_text, selected_sources, page)

            if not search_response or not search_response.get("items"):
                break

            for item in search_response.get("items", []):
                url = item.get("link")
                if not url:
                    continue

                domain = get_domain_from_url(url)

                # Keep only results from the selected source domains, capped per domain
                if is_valid_source_domain(domain, selected_sources):
                    if domain not in domain_results:
                        domain_results[domain] = []

                    if len(domain_results[domain]) < MAX_URLS_PER_DOMAIN:
                        domain_results[domain].append({
                            "url": url,
                            "title": item.get("title", ""),
                            "snippet": item.get("snippet", "")
                        })
                        all_urls.append(url)

            if len(all_urls) >= 50:
                break

        if not all_urls:
            return {
                "status": "no_results",
                "urls_found": 0
            }

        # Hand the first few collected URLs to the AI fact-check endpoint for verification
        fact_check_request = AIFactCheckRequest(
            content=request.search_text,
            urls=all_urls[:5]
        )

        return await ai_fact_check(fact_check_request)

    except Exception as e:
        logger.error(f"Error during search/fact-check process: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
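
# A minimal usage sketch (not part of this module): mounting the router and
# calling the endpoint. The module path "app.api.scrap_websites" and the absence
# of a prefix are assumptions about the surrounding project layout.
#
#   from fastapi import FastAPI
#   from app.api.scrap_websites import scrap_websites_router
#
#   app = FastAPI()
#   app.include_router(scrap_websites_router)
#
#   # POST /search with a JSON body such as:
#   #   {"search_text": "some claim to verify", "source_types": ["fact_checkers"]}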