电脑基础 · 2023年3月31日

Torch计算方法

Torch 中的计算方法与 Numpy 的计算方法很类似;Torch中带 “下划线 ” 的操作,都是in-place的。

求和

torch.sum() 对输入的 tensor 数据的某一维度求和;

1.torch.sum(input, dtype=None)
2.torch.sum(input, list: dim, bool: keepdim=False, dtype=None) → Tensor
 
input:输入一个tensor
dim:要求和的维度,可以是一个列表
keepdim:求和之后这个dim的元素个数为1,如果要保留,则keepdim=True

a = torch.ones((2, 3))
a1 = torch.sum(a)
a2 = torch.sum(a, dim=0) -》 tensor([2., 2., 2.])
a3 = torch.sum(a, dim=1) -》 tensor([3., 3.])

keepdim=True 时会保持 dim 维度,不会被squeeze;

a1 = torch.sum(a, dim=(0, 1), keepdim=True)
a2 = torch.sum(a, dim=(0,), keepdim=True) -》 tensor([[2., 2., 2.]])
a3 = torch.sum(a, dim=(1,), keepdim=True) -》 tensor([[3.], [3.]])

加法

torch.add():对两个张量进行相加,格式需相同,若格式不同则以复制的方式进行扩容后再相加。
add_() 均为in-place 形式,修改了对应变量中的数值。

x = torch.arange(1., 6.)
a = torch.randn(4)
b = torch.randn(4, 1)
# alpha * b + a, 维度不够的地方自动扩容
print(torch.add(a, b, alpha=10))
p = torch.randn(4)
q = torch.randn(4)
p.add(q, alpha=10)
p.add_(q, alpha=10)

均值

mean(),dim=0 时按行求平均值,返回(1,列数);dim=1 时按列求均值,返回(行数,1),default=None 时,返回所有元素的均值。

x = torch.arange(12).view(4, 3)
'''
注意:在这里使用的时候转一下类型,否则会报RuntimeError:
Can only calculate the mean of floating types. Got Long instead.的错误。
查看了一下x元素类型是torch.int64,根据提示添加一句x=x.float()转为tensor.float32就行
'''
x = x.float()
x_mean = torch.mean(x)
x_mean0 = torch.mean(x, dim=0, keepdim=True)
x_mean1 = torch.mean(x, dim=1, keepdim=True)

乘法

mul() 与 multiply() 是同一个函数不同名称;

a = torch.randn((1, 2))
b = torch.randn((2, 1))
print(torch.mul(a, b))

matmul() ,张量乘法, 输入可以是高维数据。

dot():input 和 output 的点乘,input 和 output 都必须是一维的张量(shape 属性中只有一个值)且元素个数相同。

mm():实现线性代数中的矩阵乘法(matrix multiplication):(n×m) × (m×p) = (n×p) 。

mv():实现矩阵和向量(matrix × vector)的乘法,input 为 n×m,output 为一维张量。

减法

torch.sub(input, other, *, alpha=1, out=None)

input:被减数,张量格式
other:减数
alpha:默认为 1
out:指定 torch.sub() 输出值被赋给的变量,可不指定。

Torch计算方法

是否有限

num = torch.tensor(1)   # 数字1

res = torch.isfinite(num)  # True

num = torch.tensor(float('inf')) # 正无穷大

res = torch.isfinite(num)  # False

是否为空

res=torch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))