import torch

# Demo: reducing a 3-D tensor with ``sum`` along different dimensions.
# ``dim`` picks the axis (or axes) to collapse; ``keepdim=True`` keeps each
# reduced axis as a size-1 dimension, so the result still broadcasts against A.

# A 3-D tensor of shape (2, 3, 3): two stacked 3x3 matrices.
A = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
    ]
)
print(A.shape)  # torch.Size([2, 3, 3])

# With no dim, sum() reduces over every element and returns a 0-d tensor.
print(A.sum())  # tensor(54)

# dim=0: add the two 3x3 matrices element-wise -> shape (3, 3).
print(A.sum(dim=0))
# tensor([[ 1,  2,  3],
#         [ 5,  6,  7],
#         [ 9, 10, 11]])

# keepdim=True keeps the reduced dim as size 1 -> shape (1, 3, 3).
print(A.sum(dim=0, keepdim=True))
# tensor([[[ 1,  2,  3],
#          [ 5,  6,  7],
#          [ 9, 10, 11]]])

# dim=1: collapse the rows of each matrix (per-column totals) -> shape (2, 3).
print(A.sum(dim=1))
# tensor([[12, 15, 18],
#         [ 3,  3,  3]])
print(A.sum(dim=1, keepdim=True))  # shape (2, 1, 3)
# tensor([[[12, 15, 18]],
#
#         [[ 3,  3,  3]]])

# dim=2: collapse the columns of each matrix (per-row totals) -> shape (2, 3).
print(A.sum(dim=2))
# tensor([[ 6, 15, 24],
#         [ 0,  3,  6]])
print(A.sum(dim=2, keepdim=True))  # shape (2, 3, 1)
# tensor([[[ 6],
#          [15],
#          [24]],
#
#         [[ 0],
#          [ 3],
#          [ 6]]])

# A tuple of dims reduces over all of them at once -> shape (3,).
print(A.sum(dim=(0, 1)))  # tensor([15, 18, 21])
# With keepdim=True both reduced dims stay as size 1 -> shape (1, 1, 3).
print(A.sum(dim=(0, 1), keepdim=True))  # tensor([[[15, 18, 21]]])