Update dependencies. Remove custom gpt4all_j wrapper.

This commit is contained in:
Iván Martínez 2023-05-08 23:41:57 +02:00
parent 92244a90b4
commit bdd8c8748b
3 changed files with 23 additions and 170 deletions

View file

@ -1,8 +1,8 @@
from gpt4all_j import GPT4All_J
from langchain.chains import RetrievalQA
from langchain.embeddings import LlamaCppEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All
def main():
# Load stored vectorstore
@ -12,14 +12,28 @@ def main():
retriever = db.as_retriever()
# Prepare the LLM
callbacks = [StreamingStdOutCallbackHandler()]
llm = GPT4All_J(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', callbacks=callbacks, verbose=False)
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
llm = GPT4All(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', backend='gptj', callbacks=callbacks, verbose=False)
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
# Interactive questions and answers
while True:
query = input("Enter a query: ")
query = input("\nEnter a query: ")
if query == "exit":
break
qa.run(query)
# Get the answer from the chain
res = qa(query)
answer, docs = res['result'], res['source_documents']
# Print the result
print("\n\n> Question:")
print(query)
print("\n> Answer:")
print(answer)
# Print the relevant sources used for the answer
for document in docs:
print("\n> " + document.metadata["source"] + ":")
print(document.page_content)
if __name__ == "__main__":
main()