Official documentation:
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
Returns the k largest elements of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
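Here is a minimal sketch of the default dimension in action. It assumes PyTorch is installed, and the 2x4 tensor is an arbitrary illustrative input. The sketch compares omitting dim with passing dim=-1, then switches to dim=0.

```python
import torch

# Arbitrary 2x4 input with no ties, so the result is deterministic
x = torch.tensor([[1.0, 5.0, 3.0, 2.0],
                  [4.0, 0.0, 6.0, 7.0]])

# dim omitted: topk works along the last dimension (within each row here)
default = torch.topk(x, k=2)
# Same thing, with the last dimension named explicitly
explicit = torch.topk(x, k=2, dim=-1)
print(default.values)   # per row: [5., 3.] and [7., 6.]
print(default.indices)  # per row: [1, 2] and [3, 2]

# dim=0 instead picks the top-k down each column
col = torch.topk(x, k=1, dim=0)
print(col.indices)      # one row: [1, 0, 1, 1]
```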
If largest is False, then the k smallest elements are returned.
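This short sketch shows largest=False, again assuming PyTorch and using made-up scores. With the default sorted=True, the smallest elements come back in ascending order. Taking the top-k of the negated tensor selects the same positions.

```python
import torch

# Made-up scores for illustration
scores = torch.tensor([0.3, 0.9, 0.1, 0.7, 0.5])

# largest=False: the 2 smallest elements, ascending because sorted=True by default
smallest = torch.topk(scores, k=2, largest=False)
print(smallest.values)   # 0.1 and 0.3
print(smallest.indices)  # positions 2 and 0

# Equivalent selection: the largest elements of the negated tensor
flipped = torch.topk(-scores, k=2)
print(flipped.indices)   # also positions 2 and 0
```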
A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.
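The result is a namedtuple, and its indices point into the original tensor. That means the indices can be passed straight to torch.gather along the same dimension. The sketch below (illustrative values, PyTorch assumed) checks that this recovers the returned values.

```python
import torch

# Illustrative input; any float tensor behaves the same way
x = torch.tensor([[0.2, 0.8, 0.5],
                  [0.9, 0.1, 0.4]])

result = torch.topk(x, k=2, dim=1)

# The namedtuple can be read by field name or unpacked like an ordinary tuple
values, indices = result
print(result.values)   # same tensor as `values`
print(result.indices)  # same tensor as `indices`, dtype torch.int64 (LongTensor)

# indices are positions in the ORIGINAL tensor, so gathering them
# along the same dim reproduces the returned values
recovered = torch.gather(x, 1, indices)
print(torch.equal(recovered, values))  # True
```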
The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
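The following sketch contrasts the default sorted=True with sorted=False. It assumes PyTorch and uses an illustrative input. In the unsorted case the docs only promise the same k elements, so the comparison is done on sets rather than on order. Sorting the unsorted values afterwards restores the descending order.

```python
import torch

# Illustrative input with no ties
x = torch.tensor([4.0, 9.0, 1.0, 7.0, 3.0, 8.0])

sorted_top = torch.topk(x, k=3)                  # sorted=True (default): descending
unsorted_top = torch.topk(x, k=3, sorted=False)  # same elements, order not guaranteed

print(sorted_top.values)  # 9., 8., 7.

# Compare as sets, since the order of unsorted_top is not specified
same_elements = set(unsorted_top.values.tolist()) == set(sorted_top.values.tolist())
print(same_elements)      # True

# Re-sorting the unsorted result gives back the descending order
resorted = unsorted_top.values.sort(descending=True).values
```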
Parameters:
input (Tensor) – the input tensor.
k (int) – the k in “top-k”.
dim (int, optional) – the dimension to sort along.
largest (bool, optional) – controls whether to return the largest or smallest elements.
sorted (bool, optional) – controls whether to return the elements in sorted order.
out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can optionally be given to be used as output buffers.
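The out parameter is the least obvious of these. Below is a minimal sketch of it. It assumes PyTorch is installed and that the buffers are pre-allocated with the expected shape (rows x k) and dtypes: float for the values and torch.long for the indices.

```python
import torch

# Illustrative 2x3 input
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.0, 5.0, 4.0]])

# Pre-allocate buffers matching the output: 2 rows, k=2 columns
values_buf = torch.empty(2, 2)                    # float32, same dtype as x
indices_buf = torch.empty(2, 2, dtype=torch.long)

# The results are written into the given buffers instead of freshly allocated tensors
torch.topk(x, 2, dim=1, out=(values_buf, indices_buf))

print(values_buf)   # rows: [3., 2.] and [5., 4.]
print(indices_buf)  # rows: [0, 2] and [1, 2]
```

Reusing buffers this way can save allocations when the same shapes are computed repeatedly, for example inside a loop. For one-off calls, the plain return value is simpler.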
An example:
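import torch

x = torch.rand(3, 5)
print(x)
# topk returns a (values, indices) namedtuple; k=1 along dim=1 picks the largest element of each row
values, indices = x.topk(k=1, dim=1)
print(indices)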
Output:
tensor([[0.9068, 0.6301, 0.3500, 0.3612, 0.8632],
        [0.6435, 0.7596, 0.5890, 0.4887, 0.3763],
        [0.7244, 0.7431, 0.2717, 0.7388, 0.7798]])
tensor([[0],
        [1],
        [4]])
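Because torch.rand produces different numbers on every run, this sketch hard-codes the values from the printed output above so the result is reproducible (PyTorch assumed). It also shows that topk keeps the selected dimension with size k. The indices therefore have shape (3, 1) rather than (3,). For k=1 with no ties, squeezing that dimension gives the same result as argmax.

```python
import torch

# Values copied from the printed example output, for reproducibility
x = torch.tensor([[0.9068, 0.6301, 0.3500, 0.3612, 0.8632],
                  [0.6435, 0.7596, 0.5890, 0.4887, 0.3763],
                  [0.7244, 0.7431, 0.2717, 0.7388, 0.7798]])

values, indices = x.topk(k=1, dim=1)
print(indices.shape)       # torch.Size([3, 1]): dim 1 is kept, with size k

# With k=1 and no ties, dropping that dimension matches argmax
top1 = indices.squeeze(1)
print(top1)                # tensor([0, 1, 4])
print(x.argmax(dim=1))     # tensor([0, 1, 4])
```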