private-gpt/private_gpt/users/crud/base.py
2024-04-23 17:48:13 +05:45

87 lines
No EOL
2.9 KiB
Python

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from private_gpt.users.db.base import Base
from fastapi import HTTPException, status
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from pydantic.error_wrappers import ValidationError
# Type variables that parameterize CRUDBase:
# - ModelType: the SQLAlchemy model class a CRUD object operates on
#   (bounded by the project's declarative ``Base``).
# - CreateSchemaType: the Pydantic schema accepted by ``create``.
# - UpdateSchemaType: the Pydantic schema accepted by ``update``.
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Base class that can be extended by other action classes.

    Provides basic CRUD and listing operations for a single SQLAlchemy
    model. Subclasses bind the model and the Pydantic create/update schemas
    through the generic parameters.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind this CRUD object to ``model``.

        :param model: The SQLAlchemy model
        :type model: Type[ModelType]
        """
        self.model = model

    @staticmethod
    def _integrity_error(exc: IntegrityError) -> HTTPException:
        """Build the HTTP 400 error reported when a commit violates a constraint.

        Only the driver-level message (``exc.orig``) is exposed to the client:
        ``str(exc)`` also embeds the SQL statement and its bound parameters,
        which can carry sensitive column values.
        """
        orig = getattr(exc, "orig", None)
        reason = orig if orig is not None else exc
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Integrity Error: {reason}",
        )

    def get_multi(self, db: Session) -> List[ModelType]:
        """Return every row of the model's table (no pagination)."""
        return db.query(self.model).all()

    def get(self, db: Session, id: int) -> Optional[ModelType]:
        """Return the row whose ``id`` column equals ``id``, or ``None``."""
        return db.query(self.model).filter(self.model.id == id).first()

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a row built from ``obj_in`` and return it refreshed.

        :param db: Active session; committed on success, rolled back on failure.
        :param obj_in: Schema whose fields become the model's constructor kwargs.
        :raises HTTPException: 400 if the commit violates a database constraint.
        """
        # NOTE(review): jsonable_encoder converts values such as datetimes and
        # UUIDs to JSON-compatible primitives before they reach the model
        # constructor -- confirm the mapped column types accept those.
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)
        try:
            db.add(db_obj)
            db.commit()
            db.refresh(db_obj)
        except IntegrityError as exc:
            db.rollback()
            raise self._integrity_error(exc) from exc
        return db_obj

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]],
    ) -> ModelType:
        """Apply ``obj_in`` to ``db_obj``, commit, and return the refreshed row.

        :param db: Active session; committed on success, rolled back on failure.
        :param db_obj: Persistent model instance to modify.
        :param obj_in: Either a dict of new values or an update schema; for a
            schema only explicitly set fields are applied (``exclude_unset``).
            Keys that are not mapped attributes of the model are ignored.
        :raises HTTPException: 400 if the commit violates a database constraint.
        """
        from sqlalchemy import inspect as sa_inspect

        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_unset=True)
        # Take the assignable field names from the model's mapper rather than
        # from the instance's loaded state: attributes expired by an earlier
        # commit are absent from the instance ``__dict__`` and would otherwise
        # be silently skipped.
        for field in sa_inspect(db_obj).mapper.attrs.keys():
            if field in update_data:
                setattr(db_obj, field, update_data[field])
        try:
            db.add(db_obj)
            db.commit()
            db.refresh(db_obj)
        except IntegrityError as exc:
            db.rollback()
            raise self._integrity_error(exc) from exc
        return db_obj

    def remove(self, db: Session, *, id: int) -> ModelType:
        """Delete the row with primary key ``id`` and return the deleted object.

        :raises HTTPException: 404 if no row with that primary key exists.
        :raises IntegrityError: if the delete violates a constraint; the
            transaction is rolled back before the error propagates.
        """
        # Session.get is the replacement for the legacy Query.get.
        obj = db.get(self.model, id)
        if obj is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"{self.model.__name__} not found with id: {id}",
            )
        try:
            db.delete(obj)
            db.commit()
        except IntegrityError:
            # Keep the session usable for the caller; preserve the error type.
            db.rollback()
            raise
        return obj