"""app.py: Inferless model handler for point-prompted image segmentation with SAM (facebook/sam-vit-huge)."""
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import requests
import torch
import numpy as np
from PIL import Image
import base64
from io import BytesIO
class InferlessPythonModel:
    """Inferless handler that segments an image with SAM from point prompts.

    ``initialize`` loads the ``facebook/sam-vit-huge`` checkpoint, ``infer``
    downloads an image, runs the model on the supplied points and returns the
    mask overlay as a base64-encoded PNG, and ``finalize`` is a no-op.
    """

    MODEL_ID = "facebook/sam-vit-huge"
    MASK_COLOR = (30, 144, 255)  # RGB tint painted over the segmented region
    MASK_OPACITY = 0.6  # weight of the tint; the image keeps 1 - opacity
    REQUEST_TIMEOUT = 30  # seconds; keeps a dead image host from hanging a request

    def initialize(self):
        """Load the SAM model and processor, on CUDA when it is available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SamModel.from_pretrained(self.MODEL_ID).to(self.device)
        self.processor = SamProcessor.from_pretrained(self.MODEL_ID)

    @staticmethod
    def _overlay_mask(mask, image_array, color, opacity):
        """Blend a tinted mask over an image and return a uint8 array.

        Args:
            mask: 2-D ``torch.Tensor`` (bool or numeric) with one value per pixel.
            image_array: ``H x W x 3`` array with values in ``0..255``.
            color: three RGB components in ``0..255`` used to tint the mask.
            opacity: blend weight of the tint, between 0 and 1.

        Returns:
            ``H x W x 3`` ``numpy.uint8`` array holding
            ``(1 - opacity) * image + opacity * tinted_mask``.

        Raises:
            ValueError: if the mask and the image do not share height and width.
        """
        # Cast first so boolean masks support max() and true division.
        mask = mask.detach().cpu().to(torch.float32)
        peak = mask.max()
        # Scale to 0..1; a mask with no positive value contributes nothing.
        mask = mask / peak if peak > 0 else torch.zeros_like(mask)
        mask_np = mask.numpy()

        image_array = np.asarray(image_array)
        if image_array.shape[:2] != mask_np.shape:
            raise ValueError(
                f"mask shape {mask_np.shape} does not match image size "
                f"{image_array.shape[:2]}"
            )

        # Broadcast the (H, W) mask against the 3 colour components at once.
        tint = np.asarray(color, dtype=np.float32)
        mask_rgb = mask_np[..., np.newaxis] * tint / 255
        blended = (1 - opacity) * (image_array / 255.0) + opacity * mask_rgb
        blended = np.clip(blended, 0, 1)
        # astype truncates toward zero rather than rounding.
        return (blended * 255).astype(np.uint8)

    def process_image(self, masks, raw_image, save_path=None):
        """Render the first predicted mask over ``raw_image`` as a base64 PNG.

        Args:
            masks: sequence of mask tensors as returned by
                ``post_process_masks``; only ``masks[0]`` is used, and after
                ``squeeze()`` its first entry is taken as the mask.
                NOTE(review): that entry is presumably the first of SAM's
                candidate masks, not necessarily the best-scoring one.
            raw_image: image convertible with ``np.array`` to ``H x W x 3``.
            save_path: optional path; when given, the overlay is also written
                there. By default nothing is written to disk.

        Returns:
            The PNG-encoded overlay as a base64 ``str``.
        """
        mask = masks[0].squeeze()[0]
        overlay_array = self._overlay_mask(
            mask, np.array(raw_image), self.MASK_COLOR, self.MASK_OPACITY
        )
        overlay = Image.fromarray(overlay_array)
        if save_path is not None:
            overlay.save(save_path)

        buffer = BytesIO()
        overlay.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def infer(self, inputs):
        """Segment the image at ``inputs["image_url"]`` around the given points.

        Args:
            inputs: mapping with ``"image_url"`` (URL of the image to fetch)
                and ``"input_points"`` (point prompts forwarded to the
                processor).

        Returns:
            ``{"generated_image_base64": <base64 PNG of the mask overlay>}``.

        Raises:
            KeyError: if a required key is missing from ``inputs``.
            requests.RequestException: if the download fails, times out or
                returns an HTTP error status.
        """
        # Two extra nesting levels are added around the caller's points;
        # presumably the batch dimensions SamProcessor expects. TODO confirm.
        input_points = [[inputs["input_points"]]]
        img_url = inputs["image_url"]

        # response.content is fully read and content-decoded, unlike the raw
        # stream, so gzip-encoded responses are decoded before PIL sees them.
        response = requests.get(img_url, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        raw_image = Image.open(BytesIO(response.content)).convert("RGB")

        with torch.no_grad():
            model_inputs = self.processor(
                raw_image, input_points=input_points, return_tensors="pt"
            ).to(self.device)
            # Swap the pixels for their embeddings before the prompt pass.
            image_embeddings = self.model.get_image_embeddings(
                model_inputs.pop("pixel_values")
            )
            model_inputs.update({"image_embeddings": image_embeddings})
            outputs = self.model(**model_inputs)

        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu(),
        )
        image_data = self.process_image(masks, raw_image)
        return {"generated_image_base64": image_data}

    def finalize(self, args):
        """Do nothing; present as a teardown hook that takes ``args``."""
        pass