[Gradio] Automatically Generating Interactive User Interfaces

Gradio is an open-source Python package that automatically builds interactive user interfaces, helping engineers quickly showcase how an algorithm performs.

Overview

Simply implement the algorithm inside an interface function, and Gradio generates an interactive web application that shows the algorithm's results right away.

Installation

Gradio requires Python 3.8 or later and can be installed with pip3:

pip3 install gradio -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com

Official Demo: Printing Text

Write main.py and implement the algorithm greet(name, intensity), which produces a text greeting.

import gradio as gr

def greet(name, intensity):
    return "Hello " * intensity + name + "!"

demo = gr.Interface(
    fn=greet,
    inputs=["text", "slider"],
    outputs=["text"],
)

demo.launch()
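Since the interface function is plain Python, its logic can be sanity-checked before wiring it into Gradio:

```python
# The same greet logic as above, checked without launching the UI
def greet(name, intensity):
    return "Hello " * intensity + name + "!"

print(greet("World", 2))  # Hello Hello World!
```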

Run the program; by default it listens locally on port 7860:

$ python3 main.py 
Running on local URL: http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.

Online Deployment

Gradio Only

The official package offers quick online deployment: set the share=True parameter of the launch() function.

# demo.launch()
demo.launch(share=True) # Share your demo with just 1 extra parameter 🚀

After running, Gradio provides a temporary public access link; in practice, however, the link could not be accessed.

$ python3 main.py 
Running on local URL: http://127.0.0.1:7860

Could not create share link. Missing file: /home/zj/anaconda3/envs/yolov5/lib/python3.8/site-packages/gradio/frpc_linux_amd64_v0.2.

Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.2
3. Move the file to this location: /home/zj/anaconda3/envs/yolov5/lib/python3.8/site-packages/gradio


$ python3 main.py
Running on local URL: http://127.0.0.1:7860

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.

Gradio + Nginx

Let's try setting up Gradio on a cloud server. First, start the Gradio program on the server.

Note 1: set the server address to 0.0.0.0 rather than localhost.

Note 2: you can specify the port that Gradio listens on locally on the server.

Note 3: set the root_path parameter to match the sub-path in the Nginx configuration.

demo.launch(server_name="0.0.0.0", server_port=7860, root_path="/gradio/demo/")
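Hard-coding these values ties the script to one deployment. A small sketch that reads them from the environment instead (the GRADIO_PORT / GRADIO_ROOT_PATH variable names are my own convention, not part of Gradio):

```python
import os

# Hypothetical helper: build launch() kwargs from environment variables,
# falling back to the values used above
def launch_kwargs(env=os.environ):
    return {
        "server_name": "0.0.0.0",
        "server_port": int(env.get("GRADIO_PORT", "7860")),
        "root_path": env.get("GRADIO_ROOT_PATH", "/gradio/demo/"),
    }

# demo.launch(**launch_kwargs())
print(launch_kwargs({"GRADIO_PORT": "7861"}))
```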

Once the program is running, it can be accessed remotely through port 7860:

# python3 main.py 
Running on local URL: http://0.0.0.0:7860

To create a public link, set `share=True` in `launch()`.

Combine it with Nginx as a reverse proxy by modifying the configuration file:

location /gradio/demo/ {  # Change this if you'd like to serve your Gradio app on a different path
    #proxy_pass http://127.0.0.1:7860/; # Change this if your Gradio app will be running on a different port
    proxy_pass http://<server-internal-IP>:7860/;
    proxy_buffering off;
    proxy_redirect off;
    proxy_http_version 1.1;
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection "upgrade";
    proxy_set_header Host $host;
}

Restart Nginx, and the app is reachable at: https://blog.zjykzj.cn/gradio/demo/

Image Classification

Gradio ships a results view tailored to image classification, displaying labels together with their class probabilities. Using Gradio + PyTorch + Torchvision, the implementation looks like this:

# -*- coding: utf-8 -*-

"""
@date: 2024/1/28 8:27 PM
@file: classify.py
@author: zj
@description:
"""
import json
# import requests
from PIL import Image

import torch
import torchvision
from torchvision import transforms

import gradio as gr

# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
model = torchvision.models.resnet18(pretrained=True).eval()


def get_labels():
    # Download human-readable labels for ImageNet.
    # response = requests.get("https://git.io/JJkYN")
    # labels = response.text.split("\n")

    with open("./files/imagenet_labels.json", 'r') as f:
        json_data = json.load(f)

    # print(type(json_data), len(json_data))
    labels = json_data

    return labels


labels = get_labels()


def predict(inp):
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
    return confidences


if __name__ == '__main__':
    gr.Interface(fn=predict,
                 inputs=gr.Image(type="pil"),
                 outputs=gr.Label(num_top_classes=3),
                 examples=["./images/lion.jpg", "./images/cheetah1.jpg"]).launch()
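The predict function returns a full {label: probability} dict, and gr.Label(num_top_classes=3) then keeps only the three most confident entries. That selection can be sketched in plain Python (the top_k helper and the sample dict are illustrative, not part of the original code):

```python
# Hypothetical helper mirroring what gr.Label(num_top_classes=3) displays
def top_k(confidences, k=3):
    return sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:k]

confs = {"lion": 0.70, "tiger": 0.20, "cat": 0.06, "dog": 0.04}
print(top_k(confs))  # [('lion', 0.7), ('tiger', 0.2), ('cat', 0.06)]
```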

Switching to Gradio + ONNX + ONNXRuntime and deploying to the cloud server (https://blog.zjykzj.cn/gradio/classify/), the implementation is as follows:

# -*- coding: utf-8 -*-

"""
@date: 2024/1/28 8:27 PM
@file: classify.py
@author: zj
@description:
"""

import os
import json

import cv2
import numpy as np

import gradio as gr


def softmax(x):
    exps = np.exp(x - np.max(x))  # subtract the max to avoid overflow
    return exps / np.sum(exps)


class Classifier:

    def __init__(self, weight: str = 'resnet18.onnx'):
        super().__init__()
        self.load_onnx(weight)

    def load_onnx(self, weight: str):
        assert os.path.isfile(weight), weight

        print(f'Loading {weight} for ONNX Runtime inference...')
        import onnxruntime
        providers = ['CPUExecutionProvider']
        session = onnxruntime.InferenceSession(weight, providers=providers)

        input_names = [x.name for x in session.get_inputs()]
        print(f"input_names: {input_names}")
        output_names = [x.name for x in session.get_outputs()]
        print(f"output_names: {output_names}")
        metadata = session.get_modelmeta().custom_metadata_map  # metadata
        print(f"metadata: {metadata}")

        self.session = session
        self.output_names = output_names
        self.dtype = np.float32
        print(f"Init Done. Work with {self.dtype}")

    def classify(self, pil_im):
        im = np.array(pil_im, dtype=self.dtype)
        im = cv2.resize(im, (224, 224))
        im = im.transpose(2, 0, 1) / 255
        im = im[None]
        print(f"im shape: {im.shape}")

        prediction = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})[0]
        print(f"prediction shape: {prediction.shape}")
        probs = softmax(prediction[0])
        confidences = {labels[i]: float(probs[i]) for i in range(1000)}
        return confidences


model = Classifier(weight="./resnet18_pytorch.onnx")


def get_labels():
    with open("./files/imagenet_labels.json", 'r') as f:
        json_data = json.load(f)

    # print(type(json_data), len(json_data))
    labels = json_data

    return labels


labels = get_labels()


def predict(inp):
    return model.classify(inp)


if __name__ == '__main__':
    gr.Interface(fn=predict,
                 inputs=gr.Image(type="pil"),
                 outputs=gr.Label(num_top_classes=3),
                 examples=["./images/lion.jpg", "./images/cheetah1.jpg"]).launch(server_name="0.0.0.0",
                                                                                 server_port=7861,
                                                                                 root_path="/gradio/classify")
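The softmax above subtracts the maximum before exponentiating; without that shift, logits of this magnitude would overflow. A quick standalone check (NumPy assumed available, as in the code above):

```python
import numpy as np

def softmax(x):
    exps = np.exp(x - np.max(x))  # the shift keeps every exponent <= 0
    return exps / np.sum(exps)

logits = np.array([1000.0, 1001.0, 1002.0])  # np.exp(1000.0) alone would overflow
probs = softmax(logits)
print(probs.sum())  # probabilities sum to 1
```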

To deploy on the cloud server, modify the Nginx configuration file and add another location block:

location /gradio/classify/ {  # Change this if you'd like to serve your Gradio app on a different path
    proxy_pass http://<server-internal-IP>:7861/;
    proxy_buffering off;
    proxy_redirect off;
    proxy_http_version 1.1;
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection "upgrade";
    proxy_set_header Host $host;
}

Object Detection

To showcase an object detection algorithm with Gradio, using Gradio + PyTorch, the implementation is as follows:

# -*- coding: utf-8 -*-

"""
@Time : 2024/1/29 20:18
@File : gradio_det.py
@Author : zj
@Description:
"""

import torch

import numpy as np
import gradio as gr

import cv2
import random
import colorsys

MODEL_NAMES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck',
               8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
               14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear',
               22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase',
               29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
               35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle',
               40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple',
               48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut',
               55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet',
               62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',
               69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase',
               76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}

CLASSES_NAME = [item[1] for item in MODEL_NAMES.items()]


def gen_colors(classes):
    """
    generate unique hues for each class and convert to bgr
    classes -- list -- class names (80 for coco dataset)
    -> list
    """
    hsvs = []
    for x in range(len(classes)):
        hsvs.append([float(x) / len(classes), 1., 0.7])
    random.seed(1234)
    random.shuffle(hsvs)
    rgbs = []
    for hsv in hsvs:
        h, s, v = hsv
        rgb = colorsys.hsv_to_rgb(h, s, v)
        rgbs.append(rgb)
    bgrs = []
    for rgb in rgbs:
        bgr = (int(rgb[2] * 255), int(rgb[1] * 255), int(rgb[0] * 255))
        bgrs.append(bgr)
    return bgrs


def draw_results(img, preds, CLASSES_NAME, is_xyxy=True):
    CLASSES_COLOR = gen_colors(CLASSES_NAME)

    overlay = img.copy()
    if len(preds) != 0:
        for pred in preds:
            box, conf, cls = pred[:4], pred[4], pred[5]

            if is_xyxy:
                x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            else:
                x1, y1, box_w, box_h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
                x2 = x1 + box_w
                y2 = y1 + box_h
            cls_name = CLASSES_NAME[int(cls)]
            color = CLASSES_COLOR[int(cls)]
            cv2.rectangle(overlay, (x1, y1), (x2, y2), color, thickness=2, lineType=cv2.LINE_AA)
            cv2.putText(overlay, '%s %.3f' % (cls_name, conf), org=(x1, int(y1 - 10)),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=1, color=color)
    return overlay


# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
# model = torch.hub.load("/home/zj/.cache/torch/hub/ultralytics_yolov5_master", 'yolov5s', source="local")


def predict(inp):
    results = model(inp.convert("RGB"))

    overlay = draw_results(np.array(inp), results.pred[0], CLASSES_NAME, is_xyxy=True)
    # cv2.imwrite("overlay.jpg", overlay)
    return overlay


if __name__ == '__main__':
    gr.Interface(fn=predict,
                 inputs=gr.Image(type="pil"),
                 outputs="image",
                 examples=["./images/bus.jpg", "./images/zidane.jpg"]).launch()
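The gen_colors helper above spreads hues evenly over the HSV circle and shuffles them with a fixed seed, so every run assigns each class the same distinct BGR colour. A standalone check of that behaviour using only the standard library:

```python
import colorsys
import random

# Same colour-generation scheme as gen_colors above: evenly spaced hues,
# deterministically shuffled, then converted HSV -> BGR
def gen_colors(classes):
    hsvs = [[float(x) / len(classes), 1., 0.7] for x in range(len(classes))]
    random.seed(1234)
    random.shuffle(hsvs)
    bgrs = []
    for h, s, v in hsvs:
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        bgrs.append((int(b * 255), int(g * 255), int(r * 255)))
    return bgrs

colors = gen_colors(["person", "bicycle", "car", "motorcycle"])
print(len(colors), len(set(colors)))  # 4 colours, all distinct
```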

Switching to Gradio + ONNX + ONNXRuntime and deploying to the cloud server (https://blog.zjykzj.cn/gradio/detect/), the implementation is as follows:

# -*- coding: utf-8 -*-

"""
@date: 2024/1/29 9:01 PM
@file: gradio_det.py
@author: zj
@description:
"""

import numpy as np
import gradio as gr

from general import CLASSES_NAME
from yolov5_util import draw_results
from yolov5_runtime_w_numpy import YOLOv5Runtime

model = YOLOv5Runtime("./yolov5s.onnx")


def predict(inp):
    boxes, confs, cls_ids = model.detect(np.array(inp))

    overlay = draw_results(np.array(inp), boxes, confs, cls_ids, CLASSES_NAME, is_xyxy=True)
    # cv2.imwrite("overlay.jpg", overlay)
    return overlay


if __name__ == '__main__':
    gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs="image",
        examples=["./images/bus.jpg", "./images/zidane.jpg"]).launch(server_name="0.0.0.0", server_port=7862,
                                                                     root_path="/gradio/detect")
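draw_results accepts boxes either as corner coordinates (is_xyxy=True) or as x, y, width, height. The conversion it applies in the latter case is simple enough to restate as a standalone helper (xywh_to_xyxy is illustrative, not from the original code):

```python
# The xywh -> xyxy conversion performed inside draw_results when is_xyxy=False
def xywh_to_xyxy(box):
    x1, y1, w, h = box
    return (x1, y1, x1 + w, y1 + h)

print(xywh_to_xyxy((10, 20, 100, 50)))  # (10, 20, 110, 70)
```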

To deploy on the cloud server, modify the Nginx configuration file and add another location block:

location /gradio/detect/ {  # Change this if you'd like to serve your Gradio app on a different path
    proxy_pass http://<server-internal-IP>:7862/;
    proxy_buffering off;
    proxy_redirect off;
    proxy_http_version 1.1;
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection "upgrade";
    proxy_set_header Host $host;
}