Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PDF file uploads as context for LLM queries #3638

Open
wants to merge 42 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
3857e22
added pdf context support
andrewwan0131 Dec 7, 2024
06d056b
These changes are in response to PR comments
andrewwan0131 Dec 8, 2024
cc66890
These changes are in response to PR comments
andrewwan0131 Dec 8, 2024
85767e5
Changed file detection to magic numbers and removed unnecessary libraries
andrewwan0131 Dec 8, 2024
c49344f
switch to using environment variable
CodingWithTim Dec 26, 2024
afbf7e6
new architecture and bug fixes
CodingWithTim Dec 30, 2024
8527b02
fix format
CodingWithTim Dec 30, 2024
8da825b
improve UI and efficiency
CodingWithTim Dec 30, 2024
b59cea8
fix formatting
CodingWithTim Dec 30, 2024
d5efd2c
fix first page only parsing issue
CodingWithTim Dec 30, 2024
f2905b9
fix first page only parsing issue
CodingWithTim Dec 30, 2024
e7ab73f
additional improvements
CodingWithTim Dec 30, 2024
f1c6185
add multilingual support
CodingWithTim Jan 2, 2025
f7e92e1
support google cloud storage
CodingWithTim Jan 4, 2025
0daef32
fix format
CodingWithTim Jan 4, 2025
32c6724
add pdf maximum page limit
CodingWithTim Jan 4, 2025
2cb0937
remove language detection
CodingWithTim Jan 5, 2025
e4c0f3b
fix format
CodingWithTim Jan 5, 2025
5c52665
support multimodal pdfchat and switch to marker pdf
CodingWithTim Jan 6, 2025
61284e0
switch to package implementation of 'is_image'
CodingWithTim Jan 6, 2025
586a2f6
flexible state variable for pdf_id
CodingWithTim Jan 6, 2025
2ea729c
fix error
CodingWithTim Jan 6, 2025
f2c4d64
Marker API Implemented + Updated Llama code if ever needed
Jan 31, 2025
cf9b408
Content Moderation implemented + couple of latency improvements
Feb 11, 2025
2bf158c
fixed bug where text moderation wasn't being flagged
andrewwan0131 Feb 12, 2025
06110d2
fixed bug where text moderation wasn't being flagged
andrewwan0131 Feb 12, 2025
34c7a8e
added image_resize functionality for image moderation
yixin-huang1 Feb 13, 2025
0955a76
minor text fix
andrewwan0131 Feb 13, 2025
2878d3c
fixed formatting
andrewwan0131 Feb 14, 2025
701f7c5
applied black formatting py3.10
andrewwan0131 Feb 14, 2025
ab2443c
fixed black version
andrewwan0131 Feb 14, 2025
89743ad
Revert "minor text fix"
yixin-huang1 Feb 16, 2025
63cb555
revert + fix some PR issues
yixin-huang1 Feb 16, 2025
bb24800
revert the spacing changes
yixin-huang1 Feb 16, 2025
4902d71
Update setup_pdfchat.sh
yixin-huang1 Feb 16, 2025
3528919
fix formatting
CodingWithTim Feb 19, 2025
bac54f0
FIX: add missing pdf_id
CodingWithTim Feb 19, 2025
1c6f911
FIX: add_text logic and make cleaner
CodingWithTim Feb 19, 2025
85ed193
FIX: small bug with push to gc
CodingWithTim Feb 19, 2025
0919376
Updated PDF Character Length Limit
Feb 20, 2025
a4b3b6b
Fixed bug with character limits
Feb 21, 2025
5a314de
Fixed bug with character limits
Feb 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fastchat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
IMAGE_MODERATION_MSG = (
"$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES."
)
PDF_MODERATION_MSG = "$MODERATION$ YOUR PDF VIOLATES OUR CONTENT MODERATION GUIDELINES."
MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
Expand All @@ -37,8 +38,13 @@
BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int(
os.getenv("FASTCHAT_BLIND_MODE_INPUT_CHAR_LEN_LIMIT", 30000)
)
PDF_CHAR_LEN_LIMIT = 320000
# Maximum conversation turns
CONVERSATION_TURN_LIMIT = 50
# Maximum PDF Page Limit
PDF_PAGE_LIMIT = 50
PDF_LIMIT_MSG = f"YOU HAVE REACHED THE MAXIMUM PDF PAGE LIMIT ({PDF_PAGE_LIMIT} PAGES). PLEASE UPLOAD A SMALLER DOCUMENT."
PDF_CHAR_LIMIT_MSG = f"YOU HAVE REACHED THE MAXIMUM PDF CHARACTER LIMIT ({PDF_CHAR_LEN_LIMIT} CHARACTERS). PLEASE UPLOAD A SMALLER DOCUMENT OR PROMPT."
# Session expiration time
SESSION_EXPIRATION_TIME = 3600
# The output dir of log files
Expand Down
33 changes: 27 additions & 6 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,18 +362,39 @@ def update_last_message(self, message: str):
def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
from fastchat.serve.vision.image import ImageFormat
import re

ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
image = images[0] # Only one image on gradio at one time
if image.image_format == ImageFormat.URL:
img_str = f'<img src="{image.url}" alt="user upload image" />'
elif image.image_format == ImageFormat.BYTES:
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()

pattern = re.compile("!\[\]\(_page_\d_Figure_\d\.jpeg\)")
embed_locations = pattern.findall(msg)

pdfchat = False
for i, embed_str in enumerate(embed_locations):
if i >= len(images):
break

image = images[i]
msg = msg.replace(
embed_str,
f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="document image" />',
)
pdfchat = True

if not pdfchat:
# vision arena only supports one image on gradio at one time
image = images[0]
if image.image_format == ImageFormat.URL:
img_str = (
f'<img src="{image.url}" alt="user upload image" />'
)
elif image.image_format == ImageFormat.BYTES:
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()

ret.append([msg, None])
else:
Expand Down
5 changes: 4 additions & 1 deletion fastchat/model/apply_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):

split_size = 4 * GB

with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
with (
tempfile.TemporaryDirectory() as tmp_base_path,
tempfile.TemporaryDirectory() as tmp_delta_path,
):
print(f"Split files for the base model to {tmp_base_path}")
split_files(base_model_path, tmp_base_path, split_size)
print(f"Split files for the delta weights to {tmp_delta_path}")
Expand Down
146 changes: 137 additions & 9 deletions fastchat/serve/gradio_block_arena_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from gradio.data_classes import FileData
import numpy as np

from io import BytesIO
import base64

from fastchat.constants import (
TEXT_MODERATION_MSG,
IMAGE_MODERATION_MSG,
Expand Down Expand Up @@ -74,16 +77,29 @@ def get_vqa_sample():
return (res, path)


def is_pdf(file_path):
    """Return True if the file at *file_path* starts with the PDF magic number.

    Detection is based on the file's leading bytes rather than its extension,
    so renamed files are still classified correctly.

    Args:
        file_path: Path to the file to inspect.

    Returns:
        True if the file exists, is readable, and begins with ``b"%PDF-"``;
        False otherwise (I/O errors are logged and treated as "not a PDF").
    """
    try:
        with open(file_path, "rb") as file:
            # The PDF spec mandates the header "%PDF-" as the first 5 bytes.
            return file.read(5) == b"%PDF-"
    except OSError as e:
        # Narrowed from a bare `except Exception`: only file-access errors
        # are expected here; anything else should surface as a real bug.
        print(f"Error: {e}")
        return False


def set_visible_image(textbox):
    """Decide whether the attachment preview column should be shown.

    Shows the column when more than one file is attached (after warning the
    user, since only a single attachment is supported) or when the single
    attachment is an image; hides it otherwise.
    """
    import filetype

    uploaded = textbox["files"]

    # Nothing attached: keep the preview hidden.
    if not uploaded:
        return invisible_image_column

    # Multiple attachments: warn, but still show the preview column.
    if len(uploaded) > 1:
        gr.Warning(
            "We only support single image or document conversations. Please start a new round if you would like to chat using this image or document."
        )
        return visible_image_column

    # Single attachment: show the preview only for actual image files.
    if filetype.is_image(uploaded[0]):
        return visible_image_column
    return invisible_image_column


def set_invisible_image():
Expand Down Expand Up @@ -161,11 +177,96 @@ def clear_history_example(request: gr.Request):
) * 5


# TODO(Chris): At some point, we would like this to be a live-reporting feature.
def report_csam_image(state, image):
    """Placeholder hook for reporting an image flagged as CSAM.

    Deliberately a no-op for now; it exists so callers have a stable entry
    point once live reporting is implemented.

    Args:
        state: Conversation state associated with the flagged upload.
        image: The flagged image object.
    """
    pass


def wrap_pdfchat_query(query, document):
    """Combine the parsed PDF text and the user's query into one prompt.

    The document is wrapped in a collapsible <details> block (so the chat
    view stays compact when rendered) and the raw user query follows it.

    Args:
        query: The user's question.
        document: The parsed document text to use as context.

    Returns:
        A single prompt string containing both context and query.
    """
    template = (
        "Answer the user query given the context.\n"
        "[QUERY CONTEXT]\n"
        "<details>\n"
        "<summary>Expand context details</summary>\n\n"
        "{doc}\n\n"
        "</details>"
        "\n\n[USER QUERY]\n\n{q}"
    )
    return template.format(doc=document, q=query)


# Number of retries for PDF parsing.
# NOTE(review): not referenced in the code visible here — confirm a caller
# actually uses it before relying on the retry behavior.
PDFPARSE_MAX_RETRY = 2
# Languages the PDF parser is configured for (display name -> ISO 639-1 code).
PDFPARSE_SUPPORTED_LANGS = {
    "English": "en",
    "Chinese": "zh",
    "Russian": "ru",
    "Spanish": "es",
    "Japanese": "ja",
    "Korean": "ko",
    "French": "fr",
    "German": "de",
    "Vietnamese": "vi",
}
# Configuration for the marker PDF parser: markdown output, with the
# comma-joined language codes from the table above.
MARKER_PDFPARSE_CONFIG = {
    "output_format": "markdown",
    "languages": ",".join(PDFPARSE_SUPPORTED_LANGS.values()),
}


def convert_base64_to_pil_image(b64_string):
    """Decode a base64-encoded image string into a PIL Image.

    Args:
        b64_string: Base64 text of an image file (any format PIL can sniff).

    Returns:
        A PIL ``Image`` backed by an in-memory buffer.
    """
    from PIL import Image

    # BytesIO accepts the decoded bytes directly. The previous
    # np.frombuffer round-trip added an extra array wrapper (and a numpy
    # dependency) for no benefit.
    image_bytes = BytesIO(base64.b64decode(b64_string))
    return Image.open(image_bytes)


def batch_convert_base64_to_images(base64_dict):
    """Decode every base64 image value in *base64_dict* concurrently.

    Args:
        base64_dict: Mapping whose values are base64-encoded image strings.

    Returns:
        Decoded PIL images, in the dict's value-iteration order.
    """
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor() as pool:
        decoded = pool.map(convert_base64_to_pil_image, base64_dict.values())
        return list(decoded)


def parse_pdf(file_path):
    """Submit a PDF to the Datalab Marker API and return the parsed content.

    Uploads the file, then polls the returned check URL (every 2 seconds, up
    to ``max_polls`` times) until parsing completes.

    Args:
        file_path: Path to the PDF file on disk.

    Returns:
        Tuple ``(markdown_text, images)`` where ``images`` is a list of PIL
        images extracted from the document.

    Raises:
        RuntimeError: If the API reports the parse failed.
        TimeoutError: If parsing does not complete within the polling window.
        requests.HTTPError: If a request returns an error status.
    """
    import requests

    url = "https://www.datalab.to/api/v1/marker"
    headers = {"X-Api-Key": str(os.getenv("MARKER_API_KEY"))}

    # `with` closes the upload handle even if the request raises; the
    # original `open(...)` inside the dict leaked the file descriptor.
    with open(file_path, "rb") as pdf_file:
        form_data = {
            "file": ("test.pdf", pdf_file, "application/pdf"),
            "force_ocr": (None, False),
            "paginate": (None, False),
            "output_format": (None, "markdown"),
            "use_llm": (None, True),
            "strip_existing_ocr": (None, False),
            "disable_image_extraction": (None, False),
        }
        response = requests.post(url, files=form_data, headers=headers)
    response.raise_for_status()
    data = response.json()

    max_polls = 300
    check_url = data["request_check_url"]

    for _ in range(max_polls):
        time.sleep(2)
        response = requests.get(check_url, headers=headers)
        response.raise_for_status()
        data = response.json()

        if data["status"] == "complete":
            break
        # NOTE(review): assumes the API reports terminal failures with
        # status "failed" — confirm against the Marker API docs.
        if data["status"] == "failed":
            raise RuntimeError(f"PDF parsing failed: {data.get('error')}")
    else:
        # Previously this fell through and read "markdown" from a stale,
        # still-processing response; fail loudly instead.
        raise TimeoutError("PDF parsing did not complete in time.")

    output_md = data["markdown"]
    output_images = batch_convert_base64_to_images(data["images"])

    return output_md, output_images


def _prepare_text_with_image(state, text, images, csam_flag):
if len(images) > 0:
if len(state.conv.get_images()) > 0:
Expand All @@ -177,7 +278,20 @@ def _prepare_text_with_image(state, text, images, csam_flag):
return text


# NOTE(chris): take multiple images later on
def _prepare_text_with_pdf(text, pdfs):
if len(pdfs) > 0:
parsed_text, imgs = parse_pdf(pdfs[0])
print("Document processed")
wrapped_text = wrap_pdfchat_query(text, parsed_text)

imgs = convert_pdf_images_to_conversation_format(imgs)

if len(imgs) > 0:
return wrapped_text, imgs
return wrapped_text, []
return text, []


def convert_images_to_conversation_format(images):
import base64

Expand All @@ -191,6 +305,20 @@ def convert_images_to_conversation_format(images):
return conv_images


def convert_pdf_images_to_conversation_format(images):
    """Convert PIL images extracted from a PDF into conversation format.

    Args:
        images: PIL image objects produced by the PDF parser.

    Returns:
        A list of conversation-format image entries (empty for no images).
    """
    # Size cap presumably matches the NSFW moderation endpoint's upload
    # limit, per the constant's name — confirm against that endpoint.
    MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
    # The pdf parser hands back PIL image objects rather than file paths,
    # so each one is passed through as pil_img.
    return [
        Image(url="").to_conversation_format(
            MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB, pil_img=pil_image
        )
        for pil_image in images
    ]


def moderate_input(state, text, all_conv_text, model_list, images, ip):
text_flagged = moderation_filter(all_conv_text, model_list)
# flagged = moderation_filter(text, [state.model_name])
Expand All @@ -213,7 +341,7 @@ def moderate_input(state, text, all_conv_text, model_list, images, ip):
state.has_csam_image = True
report_csam_image(state, images[0])

return text, image_flagged, csam_flagged
return text, text_flagged, image_flagged, csam_flagged


def add_text(
Expand Down
Loading