Gradio 5.0 版本介绍

阅读更多
Gradio logo
  1. 其他教程
  2. 使用标记功能

使用标记功能

简介

当您演示机器学习模型时,您可能希望从尝试该模型的用户那里收集数据,特别是模型行为不符合预期的数据点。捕获这些“困难”数据点非常有价值,因为它可以让您改进机器学习模型,使其更可靠和更强大。

Gradio 通过在每个 Interface 中包含一个 Flag 按钮,简化了此数据的收集。这允许用户或测试人员轻松地将数据发送回运行演示的机器。在本指南中,我们将更详细地讨论如何使用标记功能,包括 gradio.Interfacegradio.Blocks

gradio.Interface 中的 Flag 按钮

使用 Gradio 的 Interface 进行标记非常容易。默认情况下,在输出组件下方,有一个标记为 Flag 的按钮。当测试您模型的用户看到有趣的输出输入时,他们可以单击标记按钮将输入和输出数据发送回运行演示的机器。 样本将保存到 CSV 日志文件(默认情况下)。如果演示涉及图像、音频、视频或其他类型的文件,这些文件将分别保存在并行目录中,并且这些文件的路径将保存在 CSV 文件中。

gradio.Interface 中有 四个参数 控制标记的工作方式。我们将更详细地介绍它们。

  • flagging_mode:此参数可以设置为 "manual"(默认)、"auto""never"
    • manual:用户将看到一个标记按钮,并且仅当单击该按钮时才标记样本。
    • auto:用户不会看到标记按钮,但每个样本都会自动标记。
    • never:用户不会看到标记按钮,并且不会标记任何样本。
  • flagging_options:此参数可以是 None(默认)或字符串列表。
    • 如果为 None,则用户只需单击 Flag 按钮,并且不显示其他选项。
    • 如果提供字符串列表,则用户会看到多个按钮,每个按钮对应于提供的字符串之一。例如,如果此参数的值为 ["Incorrect", "Ambiguous"],则会出现标记为 标记为不正确标记为模糊 的按钮。 这仅在 flagging_mode"manual" 时适用。
    • 所选选项将与输入和输出一起记录。
  • flagging_dir:此参数接受一个字符串。
    • 它表示存储标记数据的目录的名称。
  • flagging_callback:此参数接受 FlaggingCallback 类的子类的实例
    • 使用此参数允许您编写在单击标记按钮时运行的自定义代码
    • 默认情况下,此项设置为 gr.JSONLogger 的实例

标记数据会发生什么?

flagging_dir 参数提供的目录中,JSON 文件将记录标记的数据。

以下是一个示例:下面的代码创建了嵌入在下方的计算器界面

import gradio as gr


def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        return num1 / num2


iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    allow_flagging="manual"
)

iface.launch()

当您单击上面的标记按钮时,启动界面的目录将包含一个新的标记子文件夹,其中包含一个 csv 文件。此 csv 文件包含所有标记的数据。

+-- flagged/
|   +-- logs.csv

flagged/logs.csv

num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542

如果界面涉及文件数据,例如图像和音频组件,则将创建文件夹来存储这些标记的数据。例如,image 输入到 image 输出界面将创建以下结构。

+-- flagged/
|   +-- logs.csv
|   +-- image/
|   |   +-- 0.png
|   |   +-- 1.png
|   +-- Output/
|   |   +-- 0.png
|   |   +-- 1.png

flagged/logs.csv

im,Output timestamp
im/0.png,Output/0.png,2022-02-04 19:49:58.026963
im/1.png,Output/1.png,2022-02-02 10:40:51.093412

如果您希望用户提供标记的原因,您可以将字符串列表传递给 Interface 的 flagging_options 参数。 用户在标记时必须选择这些选项之一,并且该选项将保存为 CSV 的附加列。

如果我们回到计算器示例,以下代码将创建嵌入在下方的界面。

iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual",
    flagging_options=["wrong sign", "off by one", "other"]
)

iface.launch()

当用户单击标记按钮时,csv 文件现在将包含一列,指示所选选项。

flagged/logs.csv

num1,operation,num2,Output,flag,timestamp
5,add,7,-12,wrong sign,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,off by one,2022-02-04 11:42:32.062512

使用 Blocks 进行标记

如果您正在使用 gradio.Blocks 怎么办? 一方面,Blocks 提供了更大的灵活性——您可以编写任何您想在单击按钮时运行的 Python 代码,并使用 Blocks 中的内置事件进行分配。

同时,您可能希望使用现有的 FlaggingCallback 来避免编写额外的代码。 这需要两个步骤

  1. 您必须在首次标记数据之前在代码中的某处运行回调的 .setup()
  2. 当单击标记按钮时,然后您触发回调的 .flag() 方法,确保正确收集参数并禁用典型的预处理。

这是一个带有图像棕褐色滤镜 Blocks 演示的示例,该演示允许您使用默认的 CSVLogger 标记数据

import numpy as np
import gradio as gr

def sepia(input_img, strength):
    sepia_filter = strength * np.array(
        [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
    ) + (1-strength) * np.identity(3)
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

callback = gr.CSVLogger()

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img_input = gr.Image()
            strength = gr.Slider(0, 1, 0.5)
        img_output = gr.Image()
    with gr.Row():
        btn = gr.Button("Flag")

    # This needs to be called at some point prior to the first call to callback.flag()
    callback.setup([img_input, strength, img_output], "flagged_data_points")

    img_input.change(sepia, [img_input, strength], img_output)
    strength.change(sepia, [img_input, strength], img_output)

    # We can choose which components to flag -- in this case, we'll flag all of them
    btn.click(lambda *args: callback.flag(list(args)), [img_input, strength, img_output], None, preprocess=False)

demo.launch()

隐私

重要提示:请确保您的用户了解他们提交的数据何时被保存,以及您计划如何处理这些数据。当您使用 flagging_mode=auto 时(当通过演示提交的所有数据都被标记时),这一点尤其重要

就这样!祝您构建愉快 :)