diff --git a/visual_chatgpt.py b/visual_chatgpt.py index e9614894..2b356ea4 100644 --- a/visual_chatgpt.py +++ b/visual_chatgpt.py @@ -664,7 +664,7 @@ def __init__(self, device): "or perform segmentation on this image. " "The input to this tool should be a string, representing the image_path") def inference(self, inputs): - image = Image.open(inputs) + image = Image.open(inputs).convert("RGB") pixel_values = self.image_processor(image, return_tensors="pt").pixel_values with torch.no_grad(): outputs = self.image_segmentor(pixel_values)