CLIP in Production
A complete CLIP system for zero-shot classification, semantic image search, and feature extraction. It includes Stable Diffusion integration and fine-tuning examples for specific domains.
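The snippets below assume the reference openai/CLIP package together with PyTorch, Pillow, and NumPy; the Stable Diffusion part additionally needs diffusers. One possible setup is sketched as comments (an assumption, adjust to your environment):
# Assumed environment (illustrative, not part of the original code):
#   pip install torch torchvision ftfy regex tqdm pillow numpy
#   pip install git+https://github.com/openai/CLIP.git
#   pip install diffusers transformers accelerate   # only for the Stable Diffusion section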
import torch
import clip
from PIL import Image
import numpy as np
from typing import List, Tuple
class CLIPSystem:
"""
Sistema CLIP completo para:
- Classificação zero-shot
- Busca semântica de imagens
- Extração de embeddings
"""
def __init__(self, model_name="ViT-L/14"):
"""
Args:
model_name: Modelo CLIP (ViT-B/32, ViT-L/14, ViT-L/14@336px)
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model, self.preprocess = clip.load(model_name, device=self.device)
    # =====================
    # ZERO-SHOT CLASSIFICATION
    # =====================
    def zero_shot_classify(
        self,
        image_path: str,
        categories: List[str],
        prompt_template: str = "a photo of a {}"
    ) -> List[Tuple[str, float]]:
        """
        Classifies an image into arbitrary categories
        without any task-specific training.

        Args:
            image_path: Path to the image
            categories: List of candidate categories
            prompt_template: Template used to build the text prompts

        Returns:
            List of (category, probability), sorted by probability
        """
        # Preprocess the image
        image = self.preprocess(
            Image.open(image_path)
        ).unsqueeze(0).to(self.device)
        # Build one text prompt per category
        text_prompts = [
            prompt_template.format(category)
            for category in categories
        ]
        text_tokens = clip.tokenize(text_prompts).to(self.device)
        # Compute embeddings
        with torch.no_grad():
            image_features = self.model.encode_image(image)
            text_features = self.model.encode_text(text_tokens)
            # Normalize embeddings
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            # Cosine similarity, scaled by 100 (CLIP's temperature) so the
            # softmax yields well-separated probabilities
            similarity = (100.0 * image_features @ text_features.T).squeeze(0)
            # Softmax over categories
            probs = similarity.softmax(dim=-1)
        # Return sorted by probability
        results = list(zip(categories, probs.cpu().numpy()))
        return sorted(results, key=lambda x: x[1], reverse=True)
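    # -----------------------------------------------------------------
    # Illustrative addition (not in the original code): zero-shot accuracy
    # often improves by averaging text embeddings over several prompt
    # templates ("prompt ensembling", as in the CLIP paper). A minimal
    # sketch using the same clip API as above; the method name and default
    # templates are examples, not part of the original system:
    # -----------------------------------------------------------------
    def build_text_classifier(
        self,
        categories: List[str],
        templates=("a photo of a {}", "a picture of a {}", "an image of a {}")
    ) -> torch.Tensor:
        """Returns one averaged, normalized text embedding per category."""
        class_embeddings = []
        with torch.no_grad():
            for category in categories:
                tokens = clip.tokenize(
                    [t.format(category) for t in templates]
                ).to(self.device)
                emb = self.model.encode_text(tokens)
                emb /= emb.norm(dim=-1, keepdim=True)
                mean_emb = emb.mean(dim=0)
                class_embeddings.append(mean_emb / mean_emb.norm())
        return torch.stack(class_embeddings)  # shape: (num_categories, embed_dim)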
    # =====================
    # SEMANTIC SEARCH
    # =====================
    def build_image_index(self, image_paths: List[str]) -> np.ndarray:
        """
        Builds an embedding index for search.

        Args:
            image_paths: List of image paths

        Returns:
            Array of normalized embeddings
        """
        embeddings = []
        for path in image_paths:
            image = self.preprocess(
                Image.open(path)
            ).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(image)
                features /= features.norm(dim=-1, keepdim=True)
            embeddings.append(features.cpu().numpy())
        return np.vstack(embeddings)
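    # -----------------------------------------------------------------
    # Illustrative addition (not in the original code): encoding images one
    # by one is slow on GPU; batching them through encode_image is usually
    # much faster. A minimal sketch with a hypothetical batch_size parameter:
    # -----------------------------------------------------------------
    def build_image_index_batched(
        self, image_paths: List[str], batch_size: int = 32
    ) -> np.ndarray:
        """Same output as build_image_index, but processes images in batches."""
        embeddings = []
        for start in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[start:start + batch_size]
            images = torch.stack([
                self.preprocess(Image.open(p)) for p in batch_paths
            ]).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(images)
                features /= features.norm(dim=-1, keepdim=True)
            embeddings.append(features.cpu().numpy())
        return np.vstack(embeddings)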
    def search_by_text(
        self,
        query: str,
        image_index: np.ndarray,
        image_paths: List[str],
        top_k: int = 10
    ) -> List[Tuple[str, float]]:
        """
        Searches for images matching a textual description.

        Args:
            query: Description of the desired image
            image_index: Image embeddings
            image_paths: Corresponding image paths
            top_k: Number of results

        Returns:
            List of (path, similarity)
        """
        # Embed the query
        text_tokens = clip.tokenize([query]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        # Similarity against all images
        similarities = (text_features.cpu().numpy() @ image_index.T).squeeze(0)
        # Top-K results
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (image_paths[i], float(similarities[i]))
            for i in top_indices
        ]
    def search_by_image(
        self,
        query_image_path: str,
        image_index: np.ndarray,
        image_paths: List[str],
        top_k: int = 10
    ) -> List[Tuple[str, float]]:
        """
        Searches for images similar to a query image.
        """
        # Embed the query image
        image = self.preprocess(
            Image.open(query_image_path)
        ).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        # Similarity
        similarities = (image_features.cpu().numpy() @ image_index.T).squeeze(0)
        # Top-K
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (image_paths[i], float(similarities[i]))
            for i in top_indices
        ]
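# -----------------------------------------------------------------
# Illustrative addition (not in the original code): for production-scale
# search, the brute-force numpy matmul above can be swapped for an ANN
# index. A minimal sketch assuming the faiss-cpu package is installed;
# these helper functions are hypothetical, not part of CLIPSystem.
# -----------------------------------------------------------------
def build_faiss_index(image_index: np.ndarray):
    """Wraps normalized CLIP embeddings in a FAISS inner-product index."""
    import faiss  # assumed dependency: pip install faiss-cpu
    embeddings = np.ascontiguousarray(image_index.astype(np.float32))
    # Inner product equals cosine similarity because the vectors are normalized
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index

def faiss_search(index, query_embedding: np.ndarray, top_k: int = 10):
    """Returns (scores, indices) of the top_k nearest images."""
    query = np.ascontiguousarray(query_embedding.astype(np.float32).reshape(1, -1))
    scores, indices = index.search(query, top_k)
    return scores[0], indices[0]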
# Stable Diffusion integration
class CLIPGuidedDiffusion:
    """
    Uses CLIP to rank images generated by Stable Diffusion
    (re-ranking by CLIP score, not guiding the diffusion process itself).
    """
    def __init__(self):
        from diffusers import StableDiffusionPipeline
        # Assumes a CUDA GPU (fp16 weights)
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=torch.float16
        ).to("cuda")
        self.clip_system = CLIPSystem()
    def generate_with_clip_guidance(
        self,
        prompt: str,
        negative_prompt: str = "",
        guidance_scale: float = 7.5,
        num_images: int = 4
    ):
        """
        Generates images and ranks them by CLIP score.
        """
        # Generate multiple images
        images = self.pipe(
            prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images
        ).images
        # Score each image with CLIP
        scores = []
        for img in images:
            # Save temporarily
            img.save("/tmp/temp_clip.png")
            # Classify against the prompt vs. a generic distractor
            result = self.clip_system.zero_shot_classify(
                "/tmp/temp_clip.png",
                [prompt, "random image"]
            )
            # Probability assigned to the prompt (results are sorted by score,
            # so look the prompt up explicitly rather than taking result[0])
            scores.append(dict(result)[prompt])
        # Return images sorted by score
        ranked = sorted(
            zip(images, scores),
            key=lambda x: x[1],
            reverse=True
        )
        return ranked
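    # -----------------------------------------------------------------
    # Illustrative addition (not in the original code): the temp-file round
    # trip above can be avoided by scoring PIL images in memory, since clip's
    # preprocess transform accepts a PIL.Image directly. A minimal sketch:
    # -----------------------------------------------------------------
    def clip_score(self, img: Image.Image, prompt: str) -> float:
        """Cosine similarity between an in-memory PIL image and a prompt."""
        clip_sys = self.clip_system
        image = clip_sys.preprocess(img).unsqueeze(0).to(clip_sys.device)
        tokens = clip.tokenize([prompt]).to(clip_sys.device)
        with torch.no_grad():
            image_features = clip_sys.model.encode_image(image)
            text_features = clip_sys.model.encode_text(tokens)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return float((image_features @ text_features.T).item())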
# Fine-tuning for a specific domain
class CLIPFineTuner:
    """
    Fine-tunes CLIP for a specific domain
    while preserving its zero-shot capability.
    """
    def __init__(self, base_model="ViT-B/32"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(base_model, device=self.device)
        # clip.load returns fp16 weights on CUDA; train in fp32 for stability
        self.model.float()
        # Freeze most of the model
        for param in self.model.parameters():
            param.requires_grad = False
        # Unfreeze the last blocks of both encoders
        for param in self.model.visual.transformer.resblocks[-2:].parameters():
            param.requires_grad = True
        for param in self.model.transformer.resblocks[-2:].parameters():
            param.requires_grad = True
    def fine_tune(self, dataset, epochs=10, lr=1e-6, batch_size=32):
        """
        Fine-tunes on a domain-specific dataset.
        Dataset format: [(image_path, text_description), ...]

        Note: the contrastive loss is only meaningful with more than one
        pair per step, so pairs are grouped into mini-batches.
        """
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=lr
        )
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for start in range(0, len(dataset), batch_size):
                batch = dataset[start:start + batch_size]
                if len(batch) < 2:
                    continue  # a single pair gives no contrastive signal
                # Process inputs
                images = torch.stack([
                    self.preprocess(Image.open(path)) for path, _ in batch
                ]).to(self.device)
                text_tokens = clip.tokenize(
                    [text for _, text in batch]
                ).to(self.device)
                # Forward
                image_features = self.model.encode_image(images)
                text_features = self.model.encode_text(text_tokens)
                # Normalize (out of place, so autograd can track it)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                # Symmetric contrastive (InfoNCE) loss, as in CLIP pre-training
                logits = (image_features @ text_features.T) * self.model.logit_scale.exp()
                labels = torch.arange(len(batch), device=self.device)
                loss = (
                    torch.nn.functional.cross_entropy(logits, labels) +
                    torch.nn.functional.cross_entropy(logits.T, labels)
                ) / 2
                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            print(f"Epoch {epoch+1}: Loss = {total_loss / max(num_batches, 1):.4f}")
# Usage
if __name__ == "__main__":
    clip_system = CLIPSystem(model_name="ViT-L/14")
    # Zero-shot classification
    print("=== Zero-shot Classification ===")
    results = clip_system.zero_shot_classify(
        "./my_image.jpg",
        categories=["cat", "dog", "bird", "car", "airplane"],
        prompt_template="a photo of a {}"
    )
    for category, prob in results:
        print(f"{category}: {prob:.2%}")
    # Semantic search
    print("\n=== Semantic Search ===")
    image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
    index = clip_system.build_image_index(image_paths)
    results = clip_system.search_by_text(
        "a sunset over the ocean",
        index,
        image_paths
    )
    print("Results:", results)