diff --git a/docs/api-quickstart.md b/docs/api-quickstart.md index c022a98..eba91aa 100644 --- a/docs/api-quickstart.md +++ b/docs/api-quickstart.md @@ -20,13 +20,15 @@ The Swagger interface (available at `https://myserver.com/docs`) can be used to Processes a single user message along with various model and search parameters, then returns a generated response along with any relevant resource URLs. +**Authentication Required**: Include the token in the Authorization header. + #### Request Body JSON object matching the schema: | Field | Type | Constraints | Description | |-------------------------|---------|-----------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `content` | string | Required | The user’s message or prompt content. | +| `content` | string | Required | The user's message or prompt content. | | `similarity_threshold` | float | Default: `config.search_similarity_threshold` Range: (0, 1] | Filter for relevant documents by cosine similarity. A higher threshold yields fewer but more precise documents, while a lower threshold is more inclusive. | | `temperature` | float | Default: `config.default_temperature` Range: [0.1, 1.0] | Controls the variability in generated responses. A value closer to 1.0 produces more creative/flexible answers; near 0.1 yields more deterministic results. | | `max_tokens` | int | Default: `config.default_max_tokens` Range: [1, 1024] | Limits the maximum length of the generated response. | @@ -39,6 +41,7 @@ JSON object matching the schema: ```json POST /prompt Content-Type: application/json +Authorization: Bearer 6e63fb8d-93a2-4c55-8694-c0e76a4a7233 { "content": "Why did my CI job fail?", @@ -94,6 +97,7 @@ Content-Type: application/json ```bash curl -X POST https://my-server.com/prompt \ -H "Content-Type: application/json" \ + -H "Authorization: Bearer 6e63fb8d-93a2-4c55-8694-c0e76a4a7233" \ -d '{ "content": "What caused my test job to fail on environment X?", "similarity_threshold": 0.8, @@ -121,6 +125,8 @@ Response: Extracts Root Cause Analyses (RCAs) from a Tempest test report URL. This endpoint fetches the HTML report, parses out failed tests and their tracebacks, and generates an RCA for each unique test failure. +**Authentication Required**: Include the token in the Authorization header. + #### Request Body JSON object matching the schema: @@ -134,6 +140,7 @@ JSON object matching the schema: ```json POST /rca-from-tempest Content-Type: application/json +Authorization: Bearer 6e63fb8d-93a2-4c55-8694-c0e76a4a7233 { "tempest_report_url": "https://storage.example.com/ci-logs/tempest-report.html" @@ -180,6 +187,7 @@ Content-Type: application/json ```bash curl -X POST https://my-server.com/rca-from-tempest \ -H "Content-Type: application/json" \ + -H "Authorization: Bearer 6e63fb8d-93a2-4c55-8694-c0e76a4a7233" \ -d '{ "tempest_report_url": "https://storage.example.com/ci-logs/tempest-report.html" }' diff --git a/src/rca_accelerator_chatbot/api.py b/src/rca_accelerator_chatbot/api.py index 51438c9..0544a49 100644 --- a/src/rca_accelerator_chatbot/api.py +++ b/src/rca_accelerator_chatbot/api.py @@ -2,13 +2,14 @@ FastAPI endpoints for the RCAccelerator API. """ import asyncio -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional import re import httpx from httpx_gssapi import HTTPSPNEGOAuth, OPTIONAL from bs4 import BeautifulSoup -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException, Security +from fastapi.security import APIKeyHeader from pydantic import BaseModel, Field, HttpUrl from rca_accelerator_chatbot.constants import ( @@ -19,9 +20,12 @@ from rca_accelerator_chatbot.settings import ModelSettings from rca_accelerator_chatbot.generation import discover_generative_model_names from rca_accelerator_chatbot.embeddings import discover_embeddings_model_names +from rca_accelerator_chatbot.auth import authentification app = FastAPI(title="RCAccelerator API") +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + class BaseModelSettings(BaseModel): """Base model with common settings for model configuration.""" similarity_threshold: float = Field( @@ -89,6 +93,41 @@ async def validate_rca_settings(request: RcaRequest) -> RcaRequest: return await validate_settings(request) +async def get_current_user(authorization: Optional[str] = Security(api_key_header)) -> str: + """ + Validate the authorization token and return the username. + This function is used as a dependency for protected endpoints. + """ + if not authorization: + raise HTTPException( + status_code=401, + detail="Authorization header is missing", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Extract the token from the Authorization header + token_parts = authorization.split() + if len(token_parts) != 2 or token_parts[0].lower() != "bearer": + raise HTTPException( + status_code=401, + detail="Invalid authorization header format. Use 'Bearer {token}'", + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = token_parts[1] + + # Verify the token + username = authentification.verify_token(token) + if not username: + raise HTTPException( + status_code=401, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return username + + class RcaResponseItem(BaseModel): """Response item for a single RCA.""" test_name: str @@ -167,10 +206,12 @@ async def fetch_and_parse_tempest_report(url: str) -> List[Dict[str, str]]: @app.post("/prompt") async def process_prompt( - message_data: ChatRequest = Depends(validate_chat_settings) + message_data: ChatRequest = Depends(validate_chat_settings), + _: str = Depends(get_current_user) ) -> Dict[str, Any]: """ FastAPI endpoint that processes a message and returns an answer. + Authentication required. """ generative_model_settings: ModelSettings = { "model": message_data.generative_model_name, @@ -198,10 +239,12 @@ async def process_prompt( @app.post("/rca-from-tempest", response_model=List[RcaResponseItem]) async def process_rca( - request: RcaRequest = Depends(validate_rca_settings) + request: RcaRequest = Depends(validate_rca_settings), + _: str = Depends(get_current_user) ) -> List[RcaResponseItem]: """ FastAPI endpoint that extracts Root Cause Analyses (RCAs) from a Tempest report URL. + Authentication required. """ traceback_items = await fetch_and_parse_tempest_report(str(request.tempest_report_url)) diff --git a/src/rca_accelerator_chatbot/app.py b/src/rca_accelerator_chatbot/app.py index be4c7e3..ca996bf 100644 --- a/src/rca_accelerator_chatbot/app.py +++ b/src/rca_accelerator_chatbot/app.py @@ -157,9 +157,15 @@ async def main(message: cl.Message): async def auth_callback(username: str, password: str): """ Authentication callback to validate user credentials. - Returns True if authentication is successful, False otherwise. + Returns a User object if authentication is successful, None otherwise. """ - return authentification.authenticate(username, password) + authenticated_username = authentification.authenticate(username, password) + if authenticated_username: + cl.logger.info("User %s authenticated successfully.", authenticated_username) + return cl.User(identifier=authenticated_username) + + cl.logger.error("Authentication failed for user %s.", username) + return None @cl.on_chat_resume diff --git a/src/rca_accelerator_chatbot/auth.py b/src/rca_accelerator_chatbot/auth.py index f8603c5..140a822 100644 --- a/src/rca_accelerator_chatbot/auth.py +++ b/src/rca_accelerator_chatbot/auth.py @@ -3,10 +3,12 @@ """ from abc import ABC, abstractmethod -from sqlalchemy import create_engine, MetaData, Table +from datetime import datetime, timezone +from sqlalchemy import create_engine, MetaData +from sqlalchemy.sql import select from sqlalchemy.orm import sessionmaker +from sqlalchemy.exc import SQLAlchemyError from bcrypt import checkpw -import chainlit as cl from rca_accelerator_chatbot.config import config @@ -16,10 +18,15 @@ class Authentification(ABC): """Abstract base class for user authentication.""" @abstractmethod - def authenticate(self, username: str, password: str) -> bool: + def authenticate(self, username: str, password: str) -> str | None: """Authenticate a user by username and password.""" raise NotImplementedError + @abstractmethod + def verify_token(self, token: str) -> str | None: + """Verify a token and return the associated username if valid.""" + raise NotImplementedError + # pylint: disable=too-many-instance-attributes,too-few-public-methods class DatabaseAuthentification(Authentification): @@ -27,11 +34,14 @@ class DatabaseAuthentification(Authentification): def __init__(self): self.database_url = config.auth_database_url - self.metadata = None - self.users_table = None if not self.database_url: raise ValueError("AUTH_DATABASE_URL environment variable " + "is not set.") + self.engine = None + self.session = None + self.metadata = MetaData() + self.users_table = None + self.tokens_table = None self.connect() def connect(self): @@ -39,7 +49,30 @@ def connect(self): self.engine = create_engine(self.database_url) self.session = sessionmaker(bind=self.engine) - def authenticate(self, username: str, password: str) -> cl.User | None: + # Initialize metadata and tables + self.metadata.reflect(bind=self.engine) + if 'users' in self.metadata.tables: + self.users_table = self.metadata.tables['users'] + if 'tokens' in self.metadata.tables: + self.tokens_table = self.metadata.tables['tokens'] + + def _load_table(self, table_name): + """ + Load a table from the database metadata. + Args: + table_name: Name of the table to load + Returns: + Table object if successful, None otherwise + """ + try: + self.metadata.reflect(bind=self.engine) + if table_name in self.metadata.tables: + return self.metadata.tables[table_name] + return None + except SQLAlchemyError: + return None + + def authenticate(self, username: str, password: str) -> str | None: """ Authenticate a user by checking the username and password against the database. @@ -47,15 +80,16 @@ def authenticate(self, username: str, password: str) -> cl.User | None: username: Username of the user password: Password of the user Returns: - cl.User: User object if authentication is successful, - None otherwise + str: Username if authentication is successful, None otherwise """ auth_ok = False - self.metadata = MetaData() - self.metadata.reflect(bind=self.engine) - self.users_table = Table('users', self.metadata, - autoload_with=self.engine) + # Make sure tables are loaded + if self.users_table is None: + self.users_table = self._load_table('users') + if self.users_table is None: + return None + auth_session = self.session() try: user = auth_session.query(self.users_table).filter_by( @@ -63,16 +97,46 @@ def authenticate(self, username: str, password: str) -> cl.User | None: if user and checkpw(password.encode('utf-8'), user.password_hash.encode('utf-8')): auth_ok = True + except SQLAlchemyError: + # Log the error in a production environment + auth_ok = False finally: auth_session.close() if auth_ok: - cl.logger.info("User %s authenticated successfully.", username) - return cl.User( - identifier=username, - ) - cl.logger.error("Authentication failed for user %s.", username) + return username return None + def verify_token(self, token: str) -> str | None: + """ + Verify if a token is valid and return the associated username. + Args: + token: The token to verify + Returns: + str: Username if the token is valid, None otherwise + """ + # Make sure tokens table is loaded + if self.tokens_table is None: + self.tokens_table = self._load_table('tokens') + if self.tokens_table is None: + return None + + auth_session = self.session() + try: + query = select(self.tokens_table).where( + (self.tokens_table.c.token == token) & + (self.tokens_table.c.expires_at > datetime.now(timezone.utc)) + ) + result = auth_session.execute(query).fetchone() + + if result: + return result.username + return None + except SQLAlchemyError: + # Handle database errors gracefully + return None + finally: + auth_session.close() + authentification = DatabaseAuthentification()