【Pytorch基础(3)】张量的拼接,拆分与统计

news/2023/6/9 19:15:26

一、张量的拼接

张量的拼接主要通过cat()和stack()函数实现。其中torch.cat([a, b], dim=n)是在n维度上进行两个张量的拼接,其参数n的含义代表要进行拼接操作的维度,a和b则代表要拼接的张量。在使用cat()方法时需要注意的是两个张量除了拼接的维度可以不同,其他的维度必须相同,否则会报错。示例如下:
Statistics about scores
a [class1-3, students, scores]
b [class4-9, students, scores]

import torcha = torch.rand(3, 32, 8)
b = torch.rand(6, 32, 8)print(a.shape)
print(b.shape)
print(torch.cat([a, b], dim=0).shape)# output
# torch.Size([3, 32, 8])
# torch.Size([6, 32, 8])
# torch.Size([9, 32, 8])

torch.stack([a, b], dim=n)是拼接两个张量a,b时,在维度n之前生成一个新的维度。注意,stack()方法对于带拼接的两个张量形状要求更加严格,具体来说当使用stack()方法时,要保证拼接的两个张量形状是相同的,否则会报错,示例如下:

import torcha = torch.rand(3, 32, 8)
b = torch.rand(6, 32, 8) 
c = torch.rand(3, 32, 8)print(torch.stack([a, c], dim=0).shape)
print(torch.stack([a, b], dim=0).shape)# output
# torch.Size([2, 3, 32, 8])
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# Input In [35], in <cell line: 8>()
#       5 c = torch.rand(3, 32, 8)
#       7 print(torch.stack([a, c], dim=0).shape)
# ----> 8 print(torch.stack([a, b], dim=0).shape)# RuntimeError: stack expects each tensor to be equal size, 
# but got [3, 32, 8] at entry 0 and [6, 32, 8] at entry 1

二、张量的拆分

张量的拆分主要通过split()和chunk()函数实现。其中split()是在某维度上按照定义的间隔进行维度拆分的,方法的格式为torch.split(要拆掉的张量,拆分时的间隔数,要拆分的维度索引) ,拆分后的结果将以列表的形式进行返回。示例如下:

import torcha = torch.rand(5, 32, 8)
# 对张量a中的第0维以间隔2进行拆分
b = torch.split(a, 2, 0)print(a.shape)
print(len(b))
print(b[0].shape)
print(b[1].shape)
print(b[2].shape)# output
# torch.Size([5, 32, 8])
# 3
# torch.Size([2, 32, 8])
# torch.Size([2, 32, 8])
# torch.Size([1, 32, 8])

至于chunk(),则是在某维度上按照定义的数量进行维度拆分的,方法的格式为torch.chunk(要拆掉的张量,拆分后的数量,要拆分的维度索引) ,拆分后的结果将以列表的形式进行返回。示例如下:

import torcha = torch.rand(5, 32, 8)
# 对张量a中的第1维进行拆分,拆分后可得到两个子集
b = torch.chunk(a, 2, 1)print(a.shape)
print(len(b))
print(b[0].shape)
print(b[1].shape) # output
# torch.Size([5, 32, 8])
# 2
# torch.Size([5, 16, 8])
# torch.Size([5, 16, 8])

二、张量的统计运算

pytorch中,常用的张量的取整方法有五种,分别是:

.floor() 向下取整
.ceil() 向上取整
.round() 四舍五入
.trunc() 裁剪出整数部分
.frac() 裁剪出小数部分
示例如下:

import torcha = torch.tensor(3.1415926)print(a.floor())
print(a.ceil())
print(a.round())
print(a.trunc())
print(a.frac())# output
# tensor(3.)
# tensor(4.)
# tensor(3.)
# tensor(3.)
# tensor(0.1416)

pytorch中,常用的张量统计方法有五种,分别是:

.mean() 求均值
.sum() 求和
.max() 求最大值
.min() 求最小值
.prod() 求乘积
示例如下:

import torcha = torch.tensor([1., 2., 3., 4., 5., 6., 7.])print(a.mean())
print(a.sum())
print(a.max())
print(a.min())
print(a.prod())# output
# tensor(4.)
# tensor(28.)
# tensor(7.)
# tensor(1.)
# tensor(5040.)

pytorch中,我们还可以取到一个张量最大值或最小值的索引,使用的方法是argmin()和argmax()。这个在做识别任务时非常常见,后续会讲到。

示例如下:

import torcha = torch.tensor([1., 2., 3., 4., 5., 6., 7.])print(a.argmin())
print(a.argmax()) # output
# tensor(0)
# tensor(6)

pytorch中,我们可以使用的方法torch.eq()和torch.equal()方法来判断两个张量是否相等。两者接受的参数都是两个张量,其中eq()方法的返回值是按元素位置返回True或False,False代表不等,True代表相等。而equal()方法的返回值是True或False,当两个张量完全一样时,才会返回True,不然返回False。代码示例如下:

import torcha = torch.ones(3,3)
b = torch.eye(3,3)print(torch.eq(a, b))
print(torch.equal(a, b))# output
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
# False

三、torch.eye()函数

函数原型:

result = torch.eye(n,m=None,out=None)

参数解释:
n:行数
m:列数
out:输出类型
例:

c = torch.eye(3)
print(c)
print(type(c))

输出

tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
<class ‘torch.Tensor’>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.exyb.cn/news/show-4564741.html

如若内容造成侵权/违法违规/事实不符,请联系郑州代理记账网进行投诉反馈,一经查实,立即删除!

相关文章

理解本真的REST架构风格

理解本真的REST架构风格 引子 在移动互联网、云计算迅猛发展的今天&#xff0c;作为一名Web开发者&#xff0c;如果您还没听说过“REST”这个buzzword&#xff0c;显然已经落伍了。夸张点说&#xff0c;甚至“出了门都不好意思跟别人打招呼”。尽管如此&#xff0c;对于REST这个…

php 百度语音识别 REST API demo

1&#xff0c;首先打开百度语音识别官网&#xff0c;注册一个账户成为开发者&#xff0c;接着创建一个应用&#xff0c;下载百度提供源代码 。 下载地址&#xff1a; http://yuyin.baidu.com/sdk/ 官方文档地址&#xff1a;http://yuyin.baidu.com/docs/asr/54 2&#xff0c;…

Wordpress Rest API 自定义接口开发

Wordpress Rest API 自定义接口开发 背景: 我有一个需求,已经有的wordpress的接口已经无法实现这个需求的功能。我需要自己开发一个接口。接收参数并返回我希望得到的数据。 这是一篇由wordpress小白写的高级的自定义wordpress接口的教程。也是纯通过wordpress官方文档一次次…

REST架构风格

Web技术发展与REST的由来 Web&#xff08;万维网World Wide Web的简称&#xff09;是个包罗万象的万花筒&#xff0c;不同的人从不同的角度观察&#xff0c;对于Web究竟是什么会得出大不相同的观点。作为Web开发者&#xff0c;我们需要从技术上来理解Web。从技术架构层面上看&…

PHP教程:REST API示例

如果你现在正使用iphone、android以及Web等多种平台工作&#xff0c;请看一下这篇文章&#xff0c;它会告诉你如何使用PHP创建RESTful API。Representational state transfer (REST) 是一个用于向不同应用分发数据的软件系统。Web服务系统会以JSON或者XML方式响应状态码。 REST…

php REST程序设计的uml图

整理了一下目前项目中的程序设计思路&#xff0c;以此存照&#xff0c;待以后参考

windows最小化安装mysql8

第一步&#xff1a;下载 从官网下载&#xff0c;https://dev.mysql.com/downloads/mysql 第二步&#xff1a;安装 下载后解压到目录即可。 我这里解压到D:\MYSQL\mysql-8.0.32-winx64\mysql-8.0.32-winx64 第三步&#xff1a;初始化配置 1、添加系统变量 在系统变量PATH后…

php使用个推RestAPI V2开发push

个推 RestAPI V2 本来公司考虑使用腾讯云的push&#xff0c;但是腾讯云不支持uni-APP集成&#xff0c;加上我们前端使用Dcloud开发的&#xff0c;所以push使用的个推 文章目录个推 RestAPI V2前言一、Dcloud配置二、使用步骤1.获取Token2.发送push3.发送结果补充代码前言 个推…