forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvision_transformer.py
More file actions
65 lines (49 loc) · 1.83 KB
/
vision_transformer.py
File metadata and controls
65 lines (49 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Vision Transformer (ViT) Module
================================
Classify images using a pretrained Vision Transformer (ViT)
from Hugging Face Transformers.
Can be used as a demo or imported in other scripts.
Source:
https://huggingface.co/docs/transformers/model_doc/vit
"""
from __future__ import annotations

from functools import lru_cache
from io import BytesIO

try:
    import requests
    import torch
    from PIL import Image
    from transformers import ViTForImageClassification, ViTImageProcessor
except ImportError as e:
    raise ImportError(
        "This module requires 'torch', 'transformers', 'PIL', and 'requests'. "
        "Install them with: pip install torch transformers pillow requests"
    ) from e

DEFAULT_MODEL_NAME = "google/vit-base-patch16-224"


@lru_cache(maxsize=4)
def _load_vit(model_name: str):
    """Load and memoize the (processor, model) pair for ``model_name``.

    ``from_pretrained`` reads weights from disk or the Hugging Face Hub, which
    costs far more than a single forward pass, so repeated classifications
    reuse the objects loaded on the first call. At most four distinct
    checkpoints are kept alive by the cache.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(model_name)
    model.eval()  # inference only
    return processor, model


def classify_image(
    image: Image.Image, *, model_name: str = DEFAULT_MODEL_NAME
) -> str:
    """Classify a PIL image using a pretrained ViT and return the top label.

    Args:
        image: Image to classify. Images whose mode is not ``"RGB"`` (e.g.
            RGBA PNGs, grayscale ``"L"``) are converted to RGB first, because
            the ViT checkpoints used here take 3-channel input.
        model_name: Hugging Face Hub id (or local path) of a ViT
            image-classification checkpoint. Defaults to
            ``google/vit-base-patch16-224``.

    Returns:
        The label of the highest-scoring class, looked up in the model
        config's ``id2label`` mapping.
    """
    processor, model = _load_vit(model_name)
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image yields a batch of one, so .item() extracts the index.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def demo(url: str | None = None) -> None:
    """
    Run a demo using a sample image or provided URL.

    Downloads the image, classifies it with ``classify_image`` and prints the
    predicted label. Download and decoding failures are reported on stdout
    instead of being raised.

    Args:
        url: URL of the image. If None, uses a default cat image.
    """
    if url is None:
        # Default example image (cat photo).
        url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only parses the header; force the full decode here so a
        # truncated/corrupt payload is reported below instead of crashing
        # later inside classify_image.
        image.load()
    except (
        requests.RequestException,
        OSError,
        Image.DecompressionBombError,
    ) as e:
        print(f"Failed to load image from {url}. Error: {e}")
        return
    label = classify_image(image)
    print(f"Predicted label: {label}")
if __name__ == "__main__":
    # Only when executed as a script (not on import): classify the default
    # sample image.
    demo(url=None)