>>> import torch
>>> a = torch.rand(30).reshape(2,3,5)
>>> a
tensor([[[0.6043, 0.8942, 0.6633, 0.7719, 0.3094],
         [0.3755, 0.5932, 0.0996, 0.8829, 0.4801],
         [0.1362, 0.4843, 0.2369, 0.3898, 0.5511]],

        [[0.2321, 0.4191, 0.6576, 0.1157, 0.8961],
         [0.6723, 0.8386, 0.2332, 0.3209, 0.8477],
         [0.9402, 0.4330, 0.4449, 0.3894, 0.8684]]])
>>> a.argmax(dim=0)
tensor([[0, 0, 0, 0, 1],
        [1, 1, 1, 0, 1],
        [1, 0, 1, 0, 1]])
>>> a.argmax(dim=1)
tensor([[0, 0, 0, 1, 2],
        [2, 1, 0, 2, 0]])
>>> a.argmax(dim=2)
tensor([[1, 3, 4],
        [4, 4, 0]])
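One way to read the three calls: `argmax(dim=k)` returns, for every position in the remaining axes, the index along axis `k` where the value is largest, and axis `k` itself disappears. That is why the results for `a` (shape 2 x 3 x 5) have shapes 3 x 5, 2 x 5 and 2 x 3. The sketch below recomputes the printed results in plain Python from the four-decimal values shown for `a`, so it runs without torch. The helper `argmax3d` is hypothetical, written only for illustration; it is not part of PyTorch.

```python
# Hypothetical pure-Python helper (not a PyTorch API) that mimics
# argmax(dim=...) on a 3-D nested list, to check the outputs above.

# The values printed for `a` in the session (shape 2 x 3 x 5).
A = [
    [[0.6043, 0.8942, 0.6633, 0.7719, 0.3094],
     [0.3755, 0.5932, 0.0996, 0.8829, 0.4801],
     [0.1362, 0.4843, 0.2369, 0.3898, 0.5511]],
    [[0.2321, 0.4191, 0.6576, 0.1157, 0.8961],
     [0.6723, 0.8386, 0.2332, 0.3209, 0.8477],
     [0.9402, 0.4330, 0.4449, 0.3894, 0.8684]],
]


def argmax3d(x, dim):
    """Index of the maximum along `dim`, with that axis removed."""
    shape = (len(x), len(x[0]), len(x[0][0]))
    # The result keeps the two axes other than `dim`, in their original order.
    out_shape = [n for axis, n in enumerate(shape) if axis != dim]
    result = []
    for i in range(out_shape[0]):
        row = []
        for j in range(out_shape[1]):
            # Collect the 1-D slice that runs along `dim` at position (i, j).
            if dim == 0:
                line = [x[k][i][j] for k in range(shape[0])]
            elif dim == 1:
                line = [x[i][k][j] for k in range(shape[1])]
            else:
                line = [x[i][j][k] for k in range(shape[2])]
            # Python's max() keeps the first of equal keys, so this is the
            # first index holding the largest value in the slice.
            row.append(max(range(len(line)), key=line.__getitem__))
        result.append(row)
    return result


print(argmax3d(A, 0))  # [[0, 0, 0, 0, 1], [1, 1, 1, 0, 1], [1, 0, 1, 0, 1]]
print(argmax3d(A, 1))  # [[0, 0, 0, 1, 2], [2, 1, 0, 2, 0]]
print(argmax3d(A, 2))  # [[1, 3, 4], [4, 4, 0]]
```

Each printed list matches the corresponding `a.argmax(dim=...)` result in the session, which confirms that `dim` names the axis being searched (and removed), not the axis that is kept.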