您好,欢迎访问代理记账网站
  • 价格透明
  • 信息保密
  • 进度掌控
  • 售后无忧

pytorch基本语法学习(三)

pytorch基本语法学习(三)

经过两个小节的学习,本小姐是基础知识的最后一个,主要是tensor的拼接与拆分,还会简单介绍一些相关的tensor数学运算和统计。

tensor的拼接 cat & stack

a = torch.FloatTensor(4, 3, 28, 32)
b = torch.FloatTensor(9, 3, 28, 32)
c = torch.FloatTensor(9, 3, 28, 32)
d = torch.FloatTensor(4, 5, 28, 32)
print(torch.cat([a,b], dim=0).shape)  # torch.Size([13, 3, 28, 32])
print(torch.cat([a,d], dim=1).shape)  # torch.Size([4, 8, 28, 32])
print(torch.stack([b,c],dim=0).shape) # torch.Size([2, 9, 3, 28, 32])
print(torch.stack([b,c],dim=1).shape) # torch.Size([9, 2, 3, 28, 32])
print(torch.stack([b,c],dim=2).shape) # torch.Size([9, 3, 2, 28, 32])
print(torch.stack([b,c],dim=3).shape) # torch.Size([9, 3, 28, 2, 32])

tensor的拆分 split & chunk

a = torch.FloatTensor(9, 3, 28, 32)
b = torch.FloatTensor(9, 3, 28, 32)
d = torch.FloatTensor(9, 3, 28, 32)
c = torch.stack([a,b],dim=0)
print(c.shape)  # torch.Size([2, 9, 3, 28, 32])
e = torch.stack([a,b,d],dim=0)
print(e.shape)  # torch.Size([3, 9, 3, 28, 32])
aa, bb = c.split([1,1],dim=0)  # 每份分多少
print(aa.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb.shape)  # torch.Size([1, 9, 3, 28, 32])
aa1, bb1 = c.split(1,dim=0)
print(aa1.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb1.shape)  # torch.Size([1, 9, 3, 28, 32])
aa2, bb2, cc2 = e.split(1,dim=0)
print(aa2.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb2.shape)  # torch.Size([1, 9, 3, 28, 32])
print(cc2.shape)  # torch.Size([1, 9, 3, 28, 32])
aa3, bb3 = e.split([1,2],dim=0)
print(aa3.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb3.shape)  # torch.Size([2, 9, 3, 28, 32])
aa4, bb4 = e.split([1,2],dim=2)
print(aa4.shape)  # torch.Size([3, 9, 1, 28, 32])
print(bb4.shape)  # torch.Size([3, 9, 2, 28, 32])

aa5, bb5 = e.chunk(2,dim=0)  # 分成几份
print(aa5.shape)  # torch.Size([2, 9, 3, 28, 32])
print(bb5.shape)  # torch.Size([1, 9, 3, 28, 32])
aa6, bb6, cc6 = e.chunk(3,dim=0)
print(aa6.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb6.shape)  # torch.Size([1, 9, 3, 28, 32])
print(cc6.shape)  # torch.Size([1, 9, 3, 28, 32])

tensor的数学运算

# 加、减、乘、除
a = torch.ones(2,3,4)
b = torch.ones(4)  # 完全可以按照最直接的方式编写,+,-,*,/
print(torch.all(torch.eq(a + b,torch.add(a,b))))  # tensor(True)
print(torch.all(torch.eq(a - b,torch.sub(a,b))))  # tensor(True)
print(torch.all(torch.eq(a * b,torch.mul(a,b))))  # tensor(True)
print(torch.all(torch.eq(a / b,torch.div(a,b))))  # tensor(True)

# 矩阵相乘 建议采用torch.matmul,因为torch.mm only for 2d
a = torch.FloatTensor([[3,3],[3,3]])
print(a.shape) # torch.Size([2,2])
b = torch.ones(2,2)
print(torch.mm(a, b))  # tensor([[6., 6.],
                                # [6., 6.]])
print(torch.all(torch.eq(torch.mm(a, b),torch.matmul(a,b))))  # tensor(True)

# 矩阵的平方
a = torch.FloatTensor([[3,3],[3,3]])
aa = a**2
print(torch.all(torch.eq(a.pow(2),a**2)))  # tensor(True)
print(aa.sqrt()) # 开平方  tensor([[3., 3.],
                                 # [3., 3.]])
print(torch.exp(a))  # e的a次方 tensor([[20.0855, 20.0855],
                                     # [20.0855, 20.0855]])
print(torch.log(a))  # tensor([[1.0986, 1.0986],
                            # [1.0986, 1.0986]])

b = torch.tensor(3.14)
print(b.floor())  # tensor(3.)
print(b.ceil())  # tensor(4.)
print(b.trunc())  # tensor(3.)
print(b.frac())  # tensor(0.1400)

tensor的统计

a = torch.full([8],1.0)
b = a.view(2,4)
c = a.view(2,2,2)
print(b) # tensor([[1., 1., 1., 1.],
                  # [1., 1., 1., 1.]])
print(c)  # tensor([[[1., 1.],
                 # [1., 1.]],

                 # [[1., 1.],
                 # [1., 1.]]])
print(a.norm(1))  # tensor(8.)
print(b.norm(1))  # tensor(8.)
print(c.norm(1))  # tensor(8.)
print(a.norm(2))  # tensor(2.8284)
print(b.norm(2))  # tensor(2.8284)
print(c.norm(2))  # tensor(2.8284)
print(b.norm(1, dim=1))  # tensor([4., 4.])
print(b.norm(2, dim=1))  # tensor([2., 2.])
print(c.norm(1, dim=0))  # tensor([[2., 2.],
                                  # [2., 2.]])
print(c.norm(2, dim=0))  # tensor([[1.4142, 1.4142],
                                  # [1.4142, 1.4142]])

a = torch.arange(8.0).view(2,4)
print(a)  # tensor([[0., 1., 2., 3.],
                   # [4., 5., 6., 7.]])
print(a.min())  # tensor(0.)
print(a.max())  # tensor(7.)
print(a.mean())  # tensor(3.5000)
print(a.prod())  # tensor(0.)
print(a.sum())  # tensor(28.)
print(a.argmax())  # tensor(7)
print(a.argmin())  # tensor(0)

截至目前基本的语法已经学习完成了,下一步就需要上深度学习模型了。加油加油…


分享:

低价透明

统一报价,无隐形消费

金牌服务

一对一专属顾问7*24小时金牌服务

信息保密

个人信息安全有保障

售后无忧

服务出问题客服经理全程跟进