This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)
You can set keepdim in the functions that take a keepdim argument, as shown below:
*Memos:
- I selected some popular functions that take a keepdim argument: sum(), prod(), mean(), median(), min(), max(), argmin(), argmax(), all() and any().
- keepdim (bool, optional, default: False) keeps the reduced dimension of the input tensor as a dimension of size 1 instead of removing it.
- Sometimes, keepdim needs to be used together with dim.
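The examples in this post use a 1D tensor, where keepdim only adds a pair of brackets to the output. As a minimal sketch, the effect on the shape is easier to see with two dimensions. The 2D tensor x below is an illustrative assumption, not part of the original post:

```python
import torch

# Hypothetical 2D example tensor with shape (2, 3).
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

no_keep = torch.sum(input=x, dim=1)             # reduced dimension is removed
keep = torch.sum(input=x, dim=1, keepdim=True)  # reduced dimension is kept with size 1

print(no_keep.shape)     # torch.Size([2])
print(keep.shape)        # torch.Size([2, 1])
print(no_keep.tolist())  # [6, 15]
print(keep.tolist())     # [[6], [15]]
```

The values are the same in both cases; only the number of dimensions of the result differs.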
sum(). *My post explains sum():
import torch
my_tensor = torch.tensor([1, 2, 3, 4])
torch.sum(input=my_tensor)
torch.sum(input=my_tensor, dim=0)
# tensor(10)
torch.sum(input=my_tensor, dim=0, keepdim=True)
# tensor([10])
prod(). *My post explains prod():
import torch
my_tensor = torch.tensor([1, 2, 3, 4])
torch.prod(input=my_tensor)
torch.prod(input=my_tensor, dim=0)
# tensor(24)
torch.prod(input=my_tensor, dim=0, keepdim=True)
# tensor([24])
mean(). *My post explains mean():
import torch
my_tensor = torch.tensor([5., 4., 7., 7.])
torch.mean(input=my_tensor)
torch.mean(input=my_tensor, dim=0)
# tensor(5.7500)
torch.mean(input=my_tensor, dim=0, keepdim=True)
# tensor([5.7500])
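One common reason to pass keepdim=True to mean() is broadcasting: the result keeps the same number of dimensions as the input, so it can be combined with the input directly. The following is a rough sketch of centering each row of a 2D tensor. The tensor x and the variable names are assumptions made for illustration:

```python
import torch

# Hypothetical 2D example tensor with shape (2, 2).
x = torch.tensor([[1., 3.],
                  [10., 20.]])

row_mean = torch.mean(input=x, dim=1, keepdim=True)  # shape (2, 1)
centered = x - row_mean  # (2, 2) - (2, 1) broadcasts across each row

print(row_mean.tolist())  # [[2.0], [15.0]]
print(centered.tolist())  # [[-1.0, 1.0], [-5.0, 5.0]]
```

Without keepdim=True, the mean would have shape (2,) and would be broadcast along the last dimension instead, so the subtraction would not center each row.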
median(). *My post explains median():
import torch
my_tensor = torch.tensor([5, 4, 7, 7])
torch.median(input=my_tensor, dim=0)
# torch.return_types.median(
# values=tensor(5),
# indices=tensor(0))
torch.median(input=my_tensor, dim=0, keepdim=True)
# torch.return_types.median(
# values=tensor([5]),
# indices=tensor([0]))
min(). *My post explains min():
import torch
my_tensor = torch.tensor([5, 4, 7, 7])
torch.min(input=my_tensor, dim=0)
# torch.return_types.min(
# values=tensor(4),
# indices=tensor(1))
torch.min(input=my_tensor, dim=0, keepdim=True)
# torch.return_types.min(
# values=tensor([4]),
# indices=tensor([1]))
max(). *My post explains max():
import torch
my_tensor = torch.tensor([5, 4, 7, 7])
torch.max(input=my_tensor, dim=0)
# torch.return_types.max(
# values=tensor(7),
# indices=tensor(2))
torch.max(input=my_tensor, dim=0, keepdim=True)
# torch.return_types.max(
# values=tensor([7]),
# indices=tensor([2]))
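For functions that return both values and indices, keepdim=True applies to both. As a hedged sketch of where this can help, the kept dimension lets the indices be passed straight to torch.gather(), which expects an index tensor with the same number of dimensions as the input. The 2D tensor x below is an assumed example, not taken from the original post:

```python
import torch

# Hypothetical 2D example tensor with shape (2, 3).
x = torch.tensor([[5, 4, 7],
                  [9, 2, 1]])

result = torch.max(input=x, dim=1, keepdim=True)

print(result.values.tolist())   # [[7], [9]]
print(result.indices.tolist())  # [[2], [0]]

# The keepdim=True indices have shape (2, 1), so they already have
# the same number of dimensions as x and can be used with gather().
picked = torch.gather(input=x, dim=1, index=result.indices)
print(picked.tolist())  # [[7], [9]]
```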
argmin(). *My post explains argmin():
import torch
my_tensor = torch.tensor([5, 4, 7, 7])
torch.argmin(input=my_tensor)
torch.argmin(input=my_tensor, dim=0)
# tensor(1)
torch.argmin(input=my_tensor, keepdim=True)
torch.argmin(input=my_tensor, dim=0, keepdim=True)
# tensor([1])
argmax(). *My post explains argmax():
import torch
my_tensor = torch.tensor([5, 4, 7, 7])
torch.argmax(input=my_tensor)
torch.argmax(input=my_tensor, dim=0)
# tensor(2)
torch.argmax(input=my_tensor, keepdim=True)
torch.argmax(input=my_tensor, dim=0, keepdim=True)
# tensor([2])
all(). *My post explains all():
import torch
my_tensor = torch.tensor([True, False, True, False])
torch.all(input=my_tensor)
torch.all(input=my_tensor, dim=0)
# tensor(False)
torch.all(input=my_tensor, keepdim=True)
torch.all(input=my_tensor, dim=0, keepdim=True)
# tensor([False])
any(). *My post explains any():
import torch
my_tensor = torch.tensor([True, False, True, False])
torch.any(input=my_tensor)
torch.any(input=my_tensor, dim=0)
# tensor(True)
torch.any(input=my_tensor, keepdim=True)
torch.any(input=my_tensor, dim=0, keepdim=True)
# tensor([True])
Super Kai (Kazuya Ito) | Sciencx (2024-06-28T03:32:49+00:00) Set keepdim with keepdim argument functions PyTorch. Retrieved from https://www.scien.cx/2024/06/28/set-keepdim-with-keepdim-argument-functions-pytorch/