private-gpt/private_gpt/users/api/v1/routers/auth.py
2024-05-02 10:58:03 +05:45

306 lines
11 KiB
Python

import logging
import traceback
from typing import Any, Optional
from datetime import timedelta, datetime
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
from private_gpt.users.api import deps
from private_gpt.users.core import security
from private_gpt.users.constants.role import Role
from private_gpt.users.core.config import settings
from private_gpt.users import crud, models, schemas
from private_gpt.users.utils import send_registration_email, Ldap
logger = logging.getLogger(__name__)
LDAP_SERVER = settings.LDAP_SERVER
LDAP_ENABLE = settings.LDAP_ENABLE
router = APIRouter(prefix="/auth", tags=["auth"])
def register_user(
db: Session,
email: str,
fullname: str,
password: str,
company: Optional[models.Company] = None,
department: Optional[models.Department] = None,
role: Optional[str] = None,
) -> models.User:
"""
Register a new user in the database.
"""
logging.info(f"User : {email} Password: {password} company_id: {company.id} deparment_id: {department.id}")
user_in = schemas.UserCreate(
email=email,
password=password,
username=fullname,
company_id=company.id,
department_id=department.id,
checker= True if role == 'OPERATOR' else False
)
try:
send_registration_email(fullname, email, password)
except Exception as e:
logging.info(f"Failed to send registration email: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to send registration email.")
return crud.user.create(db, obj_in=user_in)
def ldap_login(db, username, password):
ldap = Ldap(LDAP_SERVER, username, password)
username = ldap.who_am_i()
department = ldap.get_department(username)
if not ldap:
raise HTTPException(
status_code=400, detail="Incorrect email or password"
)
return username, department[0]
def create_user_role(
db: Session,
user: models.User,
role_name: str,
company: Optional[models.Company] = None,
) -> models.UserRole:
"""
Create a user role in the database.
"""
role = crud.role.get_by_name(db, name=role_name)
user_role_in = schemas.UserRoleCreate(user_id=user.id, role_id=role.id, company_id=company.id if company else None)
return crud.user_role.create(db, obj_in=user_role_in)
def create_token_payload(user: models.User, user_role: models.UserRole) -> dict:
"""
Create a token payload for authentication.
"""
return {
"id": str(user.id),
"email": str(user.email),
"role": user_role.role.name,
"username": str(user.username),
"company_id": user_role.company.id if user_role.company else None,
"department_id": user.department_id
}
def ad_user_register(
db: Session,
email: str,
fullname: str,
password: str,
department_id: int,
) -> models.User:
"""
Register a new user in the database. Company id is directly given here.
"""
user_in = schemas.UserCreate(email=email, password=password, username=fullname, company_id=1, department_id=department_id, checker=False)
user = crud.user.create(db, obj_in=user_in)
user_role_name = Role.GUEST["name"]
company = crud.company.get(db, 1)
user_role = create_user_role(db, user, user_role_name, company)
return user
@router.post("/login", response_model=schemas.TokenSchema)
def login_access_token(
log_audit: models.Audit = Depends(deps.get_audit_logger),
db: Session = Depends(deps.get_db),
form_data: OAuth2PasswordRequestForm = Depends(),
# active_subscription: models.Subscription = Depends(deps.get_active_subscription)
) -> Any:
"""
OAuth2 compatible token login, get an access token for future requests
"""
def ad_auth(LDAP_ENABLE):
if LDAP_ENABLE:
existing_user = crud.user.get_by_email(db, email=form_data.username)
if existing_user:
if existing_user.user_role.role.name == "SUPER_ADMIN":
return existing_user
else:
username, department = ldap_login(db=db, username=form_data.username, password=form_data.password)
return crud.user.get_by_name(db, name=username)
else:
username, department = ldap_login(db=db, username=form_data.username, password=form_data.password)
depart = crud.department.get_by_department_name(db, name=department)
if depart:
user = ad_user_register(db=db, email=form_data.username, fullname=username, password=form_data.password, department_id=depart.id)
else:
department_in = schemas.DepartmentCreate(name=department)
new_department = crud.department.create(db, obj_in=department_in)
user = ad_user_register(db=db, email=form_data.username, fullname=username, password=form_data.password, department_id=new_department.id)
return user
return None
if LDAP_ENABLE:
user = ad_auth(LDAP_ENABLE)
if not user:
raise HTTPException(
status_code=403,
detail="Invalid Credentials!!!",
)
else:
user = crud.user.authenticate(
db, email=form_data.username, password=form_data.password
)
if not user:
raise HTTPException(
status_code=400, detail="Incorrect email or password"
)
access_token_expires = timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
refresh_token_expires = timedelta(
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
user_in = schemas.UserUpdate(
last_login=datetime.now()
)
user = crud.user.update(db, db_obj=user, obj_in=user_in)
if user.user_role:
role = user.user_role.role.name
if user.user_role.company_id:
company_id = user.user_role.company_id
else: company_id = None
token_payload = {
"id": str(user.id),
"email": str(user.email),
"username": str(user.username),
"role": role,
"company_id": company_id,
"department_id": str(user.department_id),
}
response_dict = {
"access_token": security.create_access_token(
token_payload, expires_delta=access_token_expires
),
"refresh_token": security.create_refresh_token(
token_payload, expires_delta=refresh_token_expires
),
"user": token_payload,
"token_type": "bearer",
}
log_audit(
model='User',
action='login',
details=token_payload,
user_id=user.id
)
return JSONResponse(content=response_dict)
@router.post("/login/refresh-token", response_model=schemas.TokenSchema)
def refresh_access_token(
db: Session = Depends(deps.get_db),
refresh_token: str = Body(..., embed=True),
) -> Any:
token_payload = security.verify_refresh_token(refresh_token)
if not token_payload:
raise HTTPException(status_code=401, detail="Invalid refresh token")
access_token_expires = timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_token_expires = timedelta(
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)
response_dict = {
"access_token": security.create_access_token(token_payload, expires_delta=access_token_expires),
"refresh_token": security.create_refresh_token(token_payload, expires_delta=refresh_token_expires),
"token_type": "bearer",
}
return JSONResponse(content=response_dict)
@router.post("/register", response_model=schemas.TokenSchema)
def register(
*,
log_audit: models.Audit = Depends(deps.get_audit_logger),
db: Session = Depends(deps.get_db),
email: str = Body(...),
fullname: str = Body(...),
# password: str = Body(...),
department_id: int = Body(None, title="Department ID",
description="Department name for the user (if applicable)"),
role_name: str = Body(None, title="Role Name",
description="User role name (if applicable)"),
current_user: models.User = Security(
deps.get_current_active_user,
scopes=[Role.ADMIN["name"],
Role.SUPER_ADMIN["name"],
Role.OPERATOR["name"]],
),
) -> Any:
"""
Register new user with optional company assignment and role selection.
"""
existing_user = crud.user.get_by_email(db, email=email)
if existing_user:
log_audit(
model='User',
action='creation',
details={"status": '409', 'detail': "The user with this email already exists!", },
user_id=current_user.id
)
raise HTTPException(
status_code=409,
detail="The user with this email already exists!",
)
random_password = security.generate_random_password()
# random_password = password
try:
company_id = current_user.company_id
if company_id:
company = crud.company.get(db, company_id)
if not company:
raise HTTPException(
status_code=404,
detail="Company not found.",
)
if department_id:
department = crud.department.get_by_id(
db=db, id=department_id)
if not department:
raise HTTPException(
status_code=404,
detail="Department not found.",
)
logging.info(f"Department is {department}")
user = register_user(
db, email, fullname, random_password, company, department, role_name
)
user_role_name = role_name or Role.GUEST["name"]
user_role = create_user_role(db, user, user_role_name, company)
log_audit(model='user_roles', action='create',
details={'detail': "User role created successfully.", }, user_id=current_user.id)
except Exception as e:
print(traceback.format_exc())
raise HTTPException(
status_code=500,
detail="Unable to create account.",
)
token_payload = create_token_payload(user, user_role)
response_dict = {
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
"refresh_token": security.create_refresh_token(token_payload, expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)),
"token_type": "bearer",
}
log_audit(model='User', action='creation',
details={'detail': "User created successfully.",'username':fullname}, user_id=current_user.id)
return JSONResponse(content=response_dict, status_code=status.HTTP_201_CREATED)