题解 | #Vit中的patch embedding#
Vit中的patch embedding
https://www.nowcoder.com/practice/0c9f4697d96749198a60efeb06294001
题目链接
题目描述
在 Vision Transformer (ViT) 模型中,输入图像首先会被分割成一系列固定大小的小块(patches)。每个 patch 随后被线性地映射到一个 embedding 向量。在这些 patch embeddings 的最前面,还会额外添加一个特殊的“分类 token”(classification token),用于最终的分类任务。
已知以下参数:
- 图像的边长
- 每个 patch 的边长
- 图像的通道数
- embedding 向量的维度
请计算并输出经过 patch embedding 操作后的最终形状。
计算公式:
(其中
是因为包含了分类 token)
输入描述:
一行四个整数,分别为 。
输出描述:
一行两个整数,分别为 。
说明:
- 保证
可以被
整除。
- 解题过程中不得使用任何深度学习框架。
解题思路
本题的核心是理解 ViT 模型中对输入图像的预处理过程,并根据给定的公式进行直接计算。
-
读取输入:程序需要读取四个整数:
。
-
计算每边的 patch 数量:首先,计算在图像的单条边上可以切分出多少个 patch。这个数量是
。
-
计算总的 patch 数量:由于图像是二维的,总的 patch 数量是每边 patch 数的平方,即
。
-
计算总的 token 数量:根据题目要求,在所有 patch embeddings 前面需要加上一个分类 token。因此,最终的 token 总数是
。
-
组合最终形状:输出的形状由两部分组成:计算出的
和输入的
。
值得注意的是,输入参数中的 在本题的计算中并未使用到,可以直接忽略。整个过程是一个简单的算术运算。
代码
#include <iostream>
using namespace std;
int main() {
// C++ 中使用 long long 来处理可能的大整数,确保计算不会溢出
long long img_size, patch_size, channels, embedding_dim;
// 读取输入参数
cin >> img_size >> patch_size >> channels >> embedding_dim;
// 计算每条边上的 patch 数量
long long patches_per_side = img_size / patch_size;
// 计算总的 patch 数量 (二维),并加上分类 token
long long token_count = patches_per_side * patches_per_side + 1;
// 输出最终的 token 数量和 embedding 维度
cout << token_count << " " << embedding_dim << endl;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
// Java 中使用 long 来处理可能的大整数
long img_size = sc.nextLong();
long patch_size = sc.nextLong();
// channels 变量虽然被读取,但在计算中未使用
long channels = sc.nextLong();
long embedding_dim = sc.nextLong();
// 计算每条边上的 patch 数量
long patches_per_side = img_size / patch_size;
// 计算总的 patch 数量,并加上分类 token
long token_count = patches_per_side * patches_per_side + 1;
// 输出结果
System.out.println(token_count + " " + embedding_dim);
}
}
# 读取一行输入,并将其解析为四个整数
img_size, patch_size, channels, embedding_dim = map(int, input().split())
# 计算每条边上的 patch 数量
patches_per_side = img_size // patch_size
# 计算总的 patch 数量 (二维),并加上分类 token
token_count = patches_per_side * patches_per_side + 1
# 使用 f-string 格式化输出
print(f"{token_count} {embedding_dim}")
算法及复杂度
- 算法:本题主要考察对问题描述的理解和基本的算术运算。
- 时间复杂度:
。代码只包含若干次读写和算术运算,其执行时间不随输入规模的变化而变化。
- 空间复杂度:
。程序仅使用了几个变量来存储输入和计算结果,所需空间是常数级别的。
