Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 51 additions & 25 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import aioredis
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.security import HTTPBasic, HTTPBasicCredentials, OAuth2PasswordBearer
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.model_endpoints import BrokerType
from model_engine_server.common.env_vars import CIRCLECI
Expand Down Expand Up @@ -131,7 +131,8 @@

logger = make_logger(logger_name())

AUTH = HTTPBasic(auto_error=False)
basic_auth = HTTPBasic(auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


@dataclass
Expand Down Expand Up @@ -433,35 +434,60 @@ async def get_auth_repository():


async def verify_authentication(
credentials: HTTPBasicCredentials = Depends(AUTH),
credentials: Optional[HTTPBasicCredentials] = Depends(basic_auth),
tokens: Optional[str] = Depends(oauth2_scheme),
auth_repo: AuthenticationRepository = Depends(get_auth_repository),
) -> User:
"""
Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise,
raises a 401.
"""
username = credentials.username if credentials is not None else None
if username is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No authentication was passed in",
headers={"WWW-Authenticate": "Basic"},
)

auth = await auth_repo.get_auth_from_username_async(username=username)

if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not authenticate user",
headers={"WWW-Authenticate": "Basic"},
)

# set logger context with identity data
LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id)
LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id)

return auth
# Basic Authentication
if credentials is not None:
username = credentials.username if credentials is not None else None
if username is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No authentication was passed in",
headers={"WWW-Authenticate": "Basic"},
)

auth = await auth_repo.get_auth_from_username_async(username=username)

if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not authenticate user",
headers={"WWW-Authenticate": "Basic"},
)

# set logger context with identity data
LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id)
LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id)

return auth

# bearer token
if tokens is not None:
auth = await auth_repo.get_auth_from_username_async(username=tokens)
if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not authenticate user",
headers={"WWW-Authenticate": "Bearer"},
)

# set logger context with identity data
LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id)
LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id)

return auth

raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No authentication was passed in",
headers={"WWW-Authenticate": "Bearer"},
)


_pool: Optional[aioredis.BlockingConnectionPool] = None
Expand Down
15 changes: 10 additions & 5 deletions model-engine/tests/unit/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import datetime
from typing import Any, Dict, Iterator, Tuple
from typing import Any, Dict, Iterator, Optional, Tuple

import pytest
import pytest_asyncio
Expand All @@ -10,9 +10,10 @@
from httpx import AsyncClient
from model_engine_server.api.app import app
from model_engine_server.api.dependencies import (
AUTH,
basic_auth,
get_external_interfaces,
get_external_interfaces_read_only,
oauth2_scheme,
verify_authentication,
)
from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User
Expand Down Expand Up @@ -65,15 +66,19 @@ def get_test_auth_repository() -> Iterator[AuthenticationRepository]:


def fake_verify_authentication(
credentials: HTTPBasicCredentials = Depends(AUTH),
credentials: Optional[HTTPBasicCredentials] = Depends(basic_auth),
tokens: Optional[str] = Depends(oauth2_scheme),
auth_repo: AuthenticationRepository = Depends(get_test_auth_repository),
) -> User:
"""
Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise,
raises a 401.
"""
auth_username = credentials.username if credentials is not None else None
if not auth_username:
if credentials is not None:
auth_username = credentials.username
elif tokens is not None:
auth_username = tokens
else:
raise HTTPException(status_code=401, detail="No authentication was passed in")

auth = auth_repo.get_auth_from_username(username=auth_username)
Expand Down