mirror of
https://github.com/zylon-ai/private-gpt.git
synced 2025-12-22 23:22:57 +01:00
update at documents router
This commit is contained in:
parent
2008837110
commit
c7c05de8d1
12 changed files with 264 additions and 24 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
23
poetry.lock
generated
23
poetry.lock
generated
|
|
@ -1255,6 +1255,27 @@ uvicorn = {version = ">=0.12.0", extras = ["standard"], optional = true, markers
|
|||
[package.extras]
|
||||
all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi-filter"
|
||||
version = "1.1.0"
|
||||
description = "FastAPI filter"
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "fastapi_filter-1.1.0-py3-none-any.whl", hash = "sha256:9807a65f76855580a51c232d51b15c5d69a033b6662ef445282a4f982280347f"},
|
||||
{file = "fastapi_filter-1.1.0.tar.gz", hash = "sha256:94e2fec595afbcc383ccd7cbcfabe38b83ff4a7971eb404e92e1ad1da29274b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fastapi = ">=0.100.0,<1.0"
|
||||
pydantic = ">=2.0.0,<3.0.0"
|
||||
SQLAlchemy = {version = ">=1.4.36,<2.1.0", optional = true, markers = "extra == \"sqlalchemy\" or extra == \"all\""}
|
||||
|
||||
[package.extras]
|
||||
all = ["SQLAlchemy (>=1.4.36,<2.1.0)", "mongoengine (>=0.24.1,<0.28.0)"]
|
||||
mongoengine = ["mongoengine (>=0.24.1,<0.28.0)"]
|
||||
sqlalchemy = ["SQLAlchemy (>=1.4.36,<2.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "ffmpy"
|
||||
version = "0.3.2"
|
||||
|
|
@ -7126,4 +7147,4 @@ vector-stores-qdrant = ["llama-index-vector-stores-qdrant"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.11,<3.12"
|
||||
content-hash = "98b94c8e94a361f61d52e33c22bb62703abfa63f72f1df0256b1db3de497c05e"
|
||||
content-hash = "b0715db49b2b01f8b6c1b3d122d93de58ef052be999114390fab9fa19cc9e794"
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ def register_user(
|
|||
password: str,
|
||||
company: Optional[models.Company] = None,
|
||||
department: Optional[models.Department] = None,
|
||||
role: Optional[str] = None,
|
||||
) -> models.User:
|
||||
"""
|
||||
Register a new user in the database.
|
||||
|
|
@ -40,6 +41,7 @@ def register_user(
|
|||
username=fullname,
|
||||
company_id=company.id,
|
||||
department_id=department.id,
|
||||
checker= True if role == 'OPERATOR' else False
|
||||
)
|
||||
# try:
|
||||
# send_registration_email(fullname, email, password)
|
||||
|
|
@ -97,7 +99,7 @@ def ad_user_register(
|
|||
"""
|
||||
Register a new user in the database. Company id is directly given here.
|
||||
"""
|
||||
user_in = schemas.UserCreate(email=email, password=password, username=fullname, company_id=1, department_id=department_id)
|
||||
user_in = schemas.UserCreate(email=email, password=password, username=fullname, company_id=1, department_id=department_id, checker=False)
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
user_role_name = Role.GUEST["name"]
|
||||
company = crud.company.get(db, 1)
|
||||
|
|
@ -236,7 +238,9 @@ def register(
|
|||
description="User role name (if applicable)"),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_active_user,
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
scopes=[Role.ADMIN["name"],
|
||||
Role.SUPER_ADMIN["name"],
|
||||
Role.OPERATOR["name"]],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
|
|
@ -277,7 +281,7 @@ def register(
|
|||
)
|
||||
logging.info(f"Department is {department}")
|
||||
user = register_user(
|
||||
db, email, fullname, random_password, company, department
|
||||
db, email, fullname, random_password, company, department, role_name
|
||||
)
|
||||
user_role_name = role_name or Role.GUEST["name"]
|
||||
user_role = create_user_role(db, user, user_role_name, company)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from datetime import datetime
|
|||
|
||||
from typing import Any, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security, Request, File, UploadFile
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
|
|
@ -16,10 +17,16 @@ from private_gpt.users import crud, models, schemas
|
|||
from private_gpt.server.ingest.ingest_router import create_documents, ingest
|
||||
from private_gpt.users.models.document import MakerCheckerActionType, MakerCheckerStatus
|
||||
from private_gpt.components.ocr_components.table_ocr_api import process_both_ocr, process_ocr
|
||||
from fastapi_filter import FilterDepends, with_prefix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix='/documents', tags=['Documents'])
|
||||
|
||||
def get_username(db, id):
    """Return the ``username`` of the user that ``crud.user.get_by_id`` finds for *id*.

    The lookup result is dereferenced directly, exactly as before, so a
    missing user surfaces as an ``AttributeError`` to the caller.
    """
    return crud.user.get_by_id(db=db, id=id).username
|
||||
|
||||
CHECKER = True
|
||||
|
||||
@router.get("", response_model=List[schemas.DocumentView])
|
||||
def list_files(
|
||||
|
|
@ -35,14 +42,12 @@ def list_files(
|
|||
"""
|
||||
List the documents based on the role.
|
||||
"""
|
||||
def get_username(db, id):
|
||||
user = crud.user.get_by_id(db=db, id=id)
|
||||
return user.username
|
||||
|
||||
|
||||
try:
|
||||
role = current_user.user_role.role.name if current_user.user_role else None
|
||||
if (role == "SUPER_ADMIN") or (role == "OPERATOR"):
|
||||
docs = crud.documents.get_multi(db, skip=skip, limit=limit)
|
||||
docs = crud.documents.get_multi_documents(
|
||||
db, skip=skip, limit=limit)
|
||||
else:
|
||||
docs = crud.documents.get_documents_by_departments(
|
||||
db, department_id=current_user.department_id, skip=skip, limit=limit)
|
||||
|
|
@ -73,6 +78,51 @@ def list_files(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/pending", response_model=List[schemas.DocumentVerify])
def list_pending_files(
    request: Request,
    db: Session = Depends(deps.get_db),
    skip: int = 0,
    limit: int = 100,
    current_user: models.User = Security(
        deps.get_current_user,
        scopes=[Role.SUPER_ADMIN["name"], Role.OPERATOR["name"]],
    )
):
    """
    List the documents returned by ``crud.documents.get_files_to_verify``.

    Args:
        request: Incoming request (unused here, kept for the route signature).
        db: Database session.
        skip: Number of rows to skip (pagination offset).
        limit: Maximum number of rows to return.
        current_user: Authenticated user; only used for the scope check.

    Returns:
        A list of ``schemas.DocumentVerify`` entries.

    Raises:
        HTTPException: 500 when the listing fails for any reason.
    """
    try:
        # get_files_to_verify() only takes skip/limit; the previous
        # department_id keyword raised TypeError, so this route always
        # answered with a 500.
        docs = crud.documents.get_files_to_verify(db, skip=skip, limit=limit)

        # Uses the module-level get_username() instead of re-declaring an
        # identical nested copy.
        return [
            schemas.DocumentVerify(
                id=doc.id,
                filename=doc.filename,
                uploaded_by=get_username(db, doc.uploaded_by),
                uploaded_at=doc.uploaded_at,
                departments=[
                    schemas.DepartmentList(id=dep.id, name=dep.name)
                    for dep in doc.departments
                ],
                status=doc.status
            )
            for doc in docs
        ]
    except Exception:
        # logger.exception records the traceback through the logging
        # configuration instead of printing it to stdout.
        logger.exception("There was an error listing the file(s).")
        raise HTTPException(
            status_code=500,
            detail="Internal Server Error",
        )
|
||||
|
||||
@router.get('{department_id}', response_model=List[schemas.DocumentList])
|
||||
def list_files_by_department(
|
||||
request: Request,
|
||||
|
|
@ -265,18 +315,18 @@ async def upload_documents(
|
|||
detail="Internal Server Error: Unable to upload file.",
|
||||
)
|
||||
|
||||
|
||||
@router.post('/verify')
|
||||
async def verify_documents(
|
||||
request: Request,
|
||||
checker_in: schemas.DocumentUpdate = Depends(),
|
||||
checker_in: schemas.DocumentUpdate,
|
||||
log_audit: models.Audit = Depends(deps.get_audit_logger),
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"],
|
||||
scopes=[
|
||||
Role.SUPER_ADMIN["name"],
|
||||
Role.OPERATOR["name"]],
|
||||
Role.OPERATOR["name"]
|
||||
],
|
||||
)
|
||||
):
|
||||
"""Upload the documents."""
|
||||
|
|
@ -287,17 +337,49 @@ async def verify_documents(
|
|||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Document not found!",
|
||||
)
|
||||
|
||||
if document.verified:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Document already verified!",
|
||||
)
|
||||
if CHECKER:
|
||||
print("Current user is checker: ", current_user.checker)
|
||||
if not current_user.checker:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You are not the checker!",
|
||||
)
|
||||
|
||||
if document.uploaded_by == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot verify by same user!",
|
||||
)
|
||||
|
||||
|
||||
unchecked_path = Path(f"{UNCHECKED_DIR}/{document.filename}")
|
||||
|
||||
if checker_in.status == MakerCheckerStatus.APPROVED.value:
|
||||
checker = schemas.DocumentCheckerUpdate(
|
||||
action_type=MakerCheckerActionType.UPDATE,
|
||||
status=MakerCheckerStatus.APPROVED,
|
||||
is_enabled=checker_in.is_enabled,
|
||||
is_enabled=False,
|
||||
verified_at=datetime.now(),
|
||||
verified_by=current_user.id,
|
||||
verified=True,
|
||||
)
|
||||
crud.documents.update(db=db, db_obj= document, obj_in=checker)
|
||||
|
||||
log_audit(
|
||||
model='Document',
|
||||
action='update',
|
||||
details={
|
||||
'filename': f'{document.filename}',
|
||||
'approved': f'{current_user.id}'
|
||||
},
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if document.doc_type_id == 2:
|
||||
return await process_ocr(request, unchecked_path)
|
||||
|
|
@ -306,6 +388,7 @@ async def verify_documents(
|
|||
else:
|
||||
return await ingest(request, unchecked_path)
|
||||
|
||||
|
||||
elif checker_in.status == MakerCheckerStatus.REJECTED.value:
|
||||
checker = schemas.DocumentCheckerUpdate(
|
||||
action_type=MakerCheckerActionType.DELETE,
|
||||
|
|
@ -313,10 +396,20 @@ async def verify_documents(
|
|||
is_enabled=False,
|
||||
verified_at=datetime.now(),
|
||||
verified_by=current_user.id,
|
||||
verified=True,
|
||||
)
|
||||
crud.documents.update(db=db, db_obj=document, obj_in=checker)
|
||||
os.remove(unchecked_path)
|
||||
|
||||
crud.documents.remove(db, id=document.id)
|
||||
log_audit(
|
||||
model='Document',
|
||||
action='update',
|
||||
details={
|
||||
'filename': f'{document.filename}',
|
||||
'rejected': f'{current_user.id}'
|
||||
},
|
||||
user_id=current_user.id
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -334,3 +427,56 @@ async def verify_documents(
|
|||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error: Unable to upload file.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_id(db, username):
    """Look up a user by name and return the lookup result unchanged.

    Despite the function name, the value handed back is whatever
    ``crud.user.get_by_name`` returns, not a bare id.
    """
    return crud.user.get_by_name(db=db, name=username)
|
||||
|
||||
|
||||
@router.get('/filter', response_model=List[schemas.DocumentView])
async def get_documents(
    document_filter: schemas.DocumentFilter = Depends(),
    db: Session = Depends(deps.get_db),
    current_user: models.User = Security(
        deps.get_current_user,
        scopes=[
            Role.SUPER_ADMIN["name"],
            Role.OPERATOR["name"]
        ],
    )
) -> Any:
    """
    Return the documents matching the query-string filter.

    Args:
        document_filter: Optional filename / uploader name / action type /
            status / sort order criteria.
        db: Database session.
        current_user: Authenticated user; only used for the scope check.

    Returns:
        A list of ``schemas.DocumentView`` entries. The list is empty when an
        uploader name is given but no such user exists.

    Raises:
        HTTPException: 500 when the lookup fails for any reason.
    """
    try:
        uploader_id = None
        if document_filter.uploaded_by:
            uploader = get_id(db, document_filter.uploaded_by)
            if uploader is None:
                # An unknown uploader cannot have uploaded anything. The old
                # code fell through with None, which dropped the filter and
                # returned every document.
                return []
            uploader_id = uploader.id

        docs = crud.documents.filter_query(
            db=db,
            filename=document_filter.filename,
            uploaded_by=uploader_id,
            action_type=document_filter.action_type,
            status=document_filter.status,
            order_by=document_filter.order_by
        )

        return [
            schemas.DocumentView(
                id=doc.id,
                filename=doc.filename,
                uploaded_by=get_username(db, doc.uploaded_by),
                uploaded_at=doc.uploaded_at,
                is_enabled=doc.is_enabled,
                departments=[
                    schemas.DepartmentList(id=dep.id, name=dep.name)
                    for dep in doc.departments
                ],
                action_type=doc.action_type,
                status=doc.status
            )
            for doc in docs
        ]
    except Exception:
        # traceback.print_exc() returns None, so the old print() emitted a
        # stray "None"; log the traceback through the module logger instead.
        logger.exception("There was an error filtering the documents.")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from sqlalchemy.sql.expression import desc, asc
|
||||
from sqlalchemy import or_, and_
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users.schemas.documents import DocumentCreate, DocumentUpdate
|
||||
|
|
@ -16,11 +17,11 @@ class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
|||
return db.query(self.model).filter(Document.filename == file_name).first()
|
||||
|
||||
def get_multi_documents(
|
||||
self, db: Session, *,department_id: int, skip: int = 0, limit: int = 100
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[Document]:
|
||||
return (
|
||||
db.query(self.model)
|
||||
.filter(Document.department_id == department_id)
|
||||
.order_by(desc(getattr(Document, 'uploaded_at')))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
|
|
@ -36,13 +37,24 @@ class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
|||
.filter(document_department_association.c.department_id == department_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all().order_by(desc(getattr(Document, 'uploaded_at')))
|
||||
)
|
||||
|
||||
def get_files_to_verify(
    self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[Document]:
    """Return one page of documents whose status equals 'PENDING'.

    Args:
        db: Database session.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
    """
    pending = db.query(self.model).filter(Document.status == 'PENDING')
    page = pending.offset(skip).limit(limit)
    return page.all()
|
||||
|
||||
def get_enabled_documents_by_departments(
|
||||
self, db: Session, *, department_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[Document]:
|
||||
all_department_id = 4 # department ID for "ALL" is 4
|
||||
all_department_id = 1 # department ID for "ALL" is 1
|
||||
|
||||
return (
|
||||
db.query(self.model)
|
||||
|
|
@ -65,5 +77,33 @@ class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
|||
.all()
|
||||
)
|
||||
|
||||
def filter_query(
    self, db: Session, *,
    filename: Optional[str] = None,
    uploaded_by: Optional[str] = None,
    action_type: Optional[str] = None,
    status: Optional[str] = None,
    order_by: Optional[str] = None,
    skip: int = 0,
    limit: int = 100
) -> List[Document]:
    """Return a page of documents narrowed by the supplied optional criteria.

    Each truthy argument adds one condition: ``filename`` is a
    case-insensitive substring match, the others are equality tests.
    Results are sorted on ``uploaded_at``: descending when ``order_by`` is
    exactly "desc", ascending otherwise.
    """
    criteria = []
    if filename:
        criteria.append(Document.filename.ilike(f"%{filename}%"))
    if uploaded_by:
        criteria.append(Document.uploaded_by == uploaded_by)
    if action_type:
        criteria.append(Document.action_type == action_type)
    if status:
        criteria.append(Document.status == status)

    query = db.query(Document)
    # Apply the conditions one at a time, in the same order as before.
    for criterion in criteria:
        query = query.filter(criterion)

    direction = desc if order_by == "desc" else asc
    query = query.order_by(direction(Document.uploaded_at))

    return query.offset(skip).limit(limit).all()
|
||||
|
||||
documents = CRUDDocuments(Document)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Optional
|
||||
from fastapi_filter.contrib.sqlalchemy import Filter
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Boolean, event, select, func, update
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime
|
||||
|
||||
from private_gpt.users.db.base_class import Base
|
||||
from private_gpt.users.models.department import Department
|
||||
from private_gpt.users.models.document_department import document_department_association
|
||||
from sqlalchemy import Enum
|
||||
from enum import Enum as PythonEnum
|
||||
|
|
@ -90,3 +91,4 @@ class Document(Base):
|
|||
# update(Department).values(total_documents=total_documents).where(
|
||||
# Department.id == department_id)
|
||||
# )
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,6 @@ from .user import User, UserCreate, UserInDB, UserUpdate, UserBaseSchema, Profil
|
|||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
||||
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList, DepartmentList, DocumentEnable, DocumentDepartmentUpdate, DocumentCheckerUpdate, DocumentMakerCreate, DocumentDepartmentList, DocumentView
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList, DepartmentList, DocumentEnable, DocumentDepartmentUpdate, DocumentCheckerUpdate, DocumentMakerCreate, DocumentDepartmentList, DocumentView, DocumentVerify, DocumentFilter
|
||||
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete
|
||||
from .audit import AuditBase, AuditCreate, AuditUpdate, Audit, GetAudit
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from fastapi import Form, UploadFile, File
|
||||
|
||||
from fastapi_filter.contrib.sqlalchemy import Filter
|
||||
|
||||
|
||||
class DocumentsBase(BaseModel):
|
||||
filename: str
|
||||
|
||||
|
|
@ -16,7 +20,6 @@ class DocumentCreate(DocumentsBase):
|
|||
class DocumentUpdate(BaseModel):
|
||||
id: int
|
||||
status: str
|
||||
is_enabled: bool
|
||||
|
||||
class DocumentEnable(BaseModel):
|
||||
id: int
|
||||
|
|
@ -58,7 +61,7 @@ class DocumentCheckerUpdate(BaseModel):
|
|||
is_enabled: bool
|
||||
verified_at: datetime
|
||||
verified_by: int
|
||||
|
||||
verified: bool
|
||||
|
||||
class DocumentDepartmentList(BaseModel):
|
||||
departments_ids: str = Form(...)
|
||||
|
|
@ -79,3 +82,25 @@ class DocumentView(BaseModel):
|
|||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class DocumentVerify(BaseModel):
|
||||
id: int
|
||||
filename: str
|
||||
uploaded_by: str
|
||||
uploaded_at: datetime
|
||||
departments: List[DepartmentList] = []
|
||||
status: str
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
|
||||
class DocumentFilter(BaseModel):
|
||||
filename: Optional[str] = None
|
||||
uploaded_by: Optional[str] = None
|
||||
action_type: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
order_by: Optional[str] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ class UserBaseSchema(BaseModel):
|
|||
username: str
|
||||
company_id: int
|
||||
department_id: int
|
||||
checker: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ boto3 = {version ="^1.34.51", optional = true}
|
|||
gradio = {version ="^4.19.2", optional = true}
|
||||
aiofiles = "^23.2.1"
|
||||
timm = "^0.9.16"
|
||||
fastapi-filter = {extras = ["sqlalchemy"], version = "^1.1.0"}
|
||||
|
||||
[tool.poetry.extras]
|
||||
ui = ["gradio"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue