forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvision_transformer.py
More file actions
58 lines (43 loc) · 1.61 KB
/
vision_transformer.py
File metadata and controls
58 lines (43 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Vision Transformer (ViT) Module
================================
Classify images using a pretrained Vision Transformer (ViT)
from Hugging Face Transformers.
"""
from functools import lru_cache
from io import BytesIO
from typing import Optional

import requests
import torch
from PIL import Image, UnidentifiedImageError
from transformers import ViTForImageClassification, ViTImageProcessor
@lru_cache(maxsize=1)
def _load_vit(
    model_name: str,
) -> "tuple[ViTImageProcessor, ViTForImageClassification]":
    """Load the processor/model pair for ``model_name`` and memoize it."""
    # from_pretrained is expensive; caching means repeated classify_image
    # calls reuse one instance. maxsize=1 keeps at most one model resident.
    processor = ViTImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(model_name)
    return processor, model


def classify_image(
    image: Image.Image, model_name: str = "google/vit-base-patch16-224"
) -> str:
    """Classify a PIL image using a pretrained ViT checkpoint.

    Args:
        image: The image to classify.
        model_name: Name passed to ``from_pretrained`` for both the image
            processor and the classification model.

    Returns:
        The ``id2label`` entry of the model config for the class with the
        highest logit.
    """
    processor, model = _load_vit(model_name)

    # NOTE(review): the processor presumably expects 3-channel input, so
    # grayscale/palette/RGBA images are normalised to RGB first. Images that
    # are already RGB are passed through untouched.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def demo(url: Optional[str] = None) -> None:
    """
    Download an image, classify it and print the predicted label.

    Args:
        url (Optional[str]): URL of the image. If None, a built-in sample
            image URL is used.

    Request failures, HTTP error statuses and undecodable image data are
    reported on stdout; in those cases the function returns without
    classifying anything.
    """
    if url is None:
        url = (
            "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"
        )  # default example image
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only reads the header; force the full decode here so a
        # truncated or corrupt download is caught by this handler instead of
        # surfacing later inside classify_image.
        image.load()
    except (requests.RequestException, UnidentifiedImageError, OSError) as e:
        # OSError covers decode failures raised by image.load().
        print(f"Failed to load image from {url}. Error: {e}")
        return
    label = classify_image(image)
    print(f"Predicted label: {label}")
# Script entry point: run the demo with the default sample image only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo()