题解 | #Vit中的patch embedding#

Vit中的patch embedding

https://www.nowcoder.com/practice/0c9f4697d96749198a60efeb06294001

题目链接

Vit中的patch embedding

题目描述

在 Vision Transformer (ViT) 模型中,输入图像首先会被分割成一系列固定大小的小块(patches)。每个 patch 随后被线性地映射到一个 embedding 向量。在这些 patch embeddings 的最前面,还会额外添加一个特殊的“分类 token”(classification token),用于最终的分类任务。

已知以下参数:

  • 图像的边长
  • 每个 patch 的边长
  • 图像的通道数
  • embedding 向量的维度

请计算并输出经过 patch embedding 操作后的最终形状。

计算公式:

  • (其中 是因为包含了分类 token)

输入描述: 一行四个整数,分别为

输出描述: 一行两个整数,分别为

说明:

  • 保证 可以被 整除。
  • 解题过程中不得使用任何深度学习框架。

解题思路

本题的核心是理解 ViT 模型中对输入图像的预处理过程,并根据给定的公式进行直接计算。

  1. 读取输入:程序需要读取四个整数:

  2. 计算每边的 patch 数量:首先,计算在图像的单条边上可以切分出多少个 patch。这个数量是

  3. 计算总的 patch 数量:由于图像是二维的,总的 patch 数量是每边 patch 数的平方,即

  4. 计算总的 token 数量:根据题目要求,在所有 patch embeddings 前面需要加上一个分类 token。因此,最终的 token 总数是

  5. 组合最终形状:输出的形状由两部分组成:计算出的 和输入的

值得注意的是,输入参数中的 在本题的计算中并未使用到,可以直接忽略。整个过程是一个简单的算术运算。

代码

#include <iostream>

using namespace std;

int main() {
    // C++ 中使用 long long 来处理可能的大整数,确保计算不会溢出
    long long img_size, patch_size, channels, embedding_dim;
    
    // 读取输入参数
    cin >> img_size >> patch_size >> channels >> embedding_dim;
    
    // 计算每条边上的 patch 数量
    long long patches_per_side = img_size / patch_size;
    
    // 计算总的 patch 数量 (二维),并加上分类 token
    long long token_count = patches_per_side * patches_per_side + 1;
    
    // 输出最终的 token 数量和 embedding 维度
    cout << token_count << " " << embedding_dim << endl;
    
    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        
        // Java 中使用 long 来处理可能的大整数
        long img_size = sc.nextLong();
        long patch_size = sc.nextLong();
        // channels 变量虽然被读取,但在计算中未使用
        long channels = sc.nextLong();
        long embedding_dim = sc.nextLong();
        
        // 计算每条边上的 patch 数量
        long patches_per_side = img_size / patch_size;
        
        // 计算总的 patch 数量,并加上分类 token
        long token_count = patches_per_side * patches_per_side + 1;
        
        // 输出结果
        System.out.println(token_count + " " + embedding_dim);
    }
}
# 读取一行输入,并将其解析为四个整数
img_size, patch_size, channels, embedding_dim = map(int, input().split())

# 计算每条边上的 patch 数量
patches_per_side = img_size // patch_size

# 计算总的 patch 数量 (二维),并加上分类 token
token_count = patches_per_side * patches_per_side + 1

# 使用 f-string 格式化输出
print(f"{token_count} {embedding_dim}")

算法及复杂度

  • 算法:本题主要考察对问题描述的理解和基本的算术运算。
  • 时间复杂度。代码只包含若干次读写和算术运算,其执行时间不随输入规模的变化而变化。
  • 空间复杂度。程序仅使用了几个变量来存储输入和计算结果,所需空间是常数级别的。
全部评论

相关推荐

点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务