This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)
*Memos:
`cat()` concatenates one or more tensors (1D or more D, each with zero or more elements) along an existing dimension and returns a tensor with the same number of dimensions as the inputs — no additional dimension is added — as shown below:
*Memos:
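- `cat()` can be used with `torch` (as `torch.cat()`) but not with a tensor.
- The 1st argument with `torch` is `tensors` (Required-Type: `tuple` or `list` of `tensor` of `int`, `float`, `complex` or `bool`):
  - Basically, the tensors must have the same size in every dimension except the one being concatenated along.
  - A 1D empty tensor of size `(0,)` is also accepted and contributes no elements (see the last example).
- The 2nd argument with `torch` is `dim` (Optional-Default: `0`-Type: `int`).
- There is an `out` argument with `torch` (Optional-Type: `tensor`):
  - `out=` must be used (i.e. pass it by keyword).
  - My post explains the `out` argument.
- The returned tensor has the same number of dimensions as the input tensors (not one more).
- `concat()` is an alias of `cat()`.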
import torch
# 1D example: three tensors of shape (3,).
# NOTE: the bare torch.cat(...) lines below are REPL-style; their result is only
# displayed in an interactive session. In a script, wrap them in print().
tensor1 = torch.tensor([2, 7, 4])
tensor2 = torch.tensor([8, 3, 2])
tensor3 = torch.tensor([5, 0, 8])
# A 1D tensor has a single dimension, so the default dim=0, an explicit dim=0
# and dim=-1 all refer to it: the three calls give the same shape-(9,) result.
torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8])
# 2D example: three tensors of shape (2, 3).
tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]])
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]])
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]])
# dim=0 (the default) and dim=-2 are the same dimension for a 2D tensor:
# the rows of each input are appended one after another -> shape (6, 3).
torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-2)
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])
# dim=1 and dim=-1 join the inputs column-wise, row by row -> shape (2, 9).
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=1)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([[2, 7, 4, 5, 0, 8, 9, 4, 7],
#         [8, 3, 2, 3, 6, 1, 1, 0, 5]])
# 3D example: three tensors of shape (2, 2, 3).
tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
[[5, 0, 8], [3, 6, 1]]])
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
[[6, 7, 4], [2, 1, 9]]])
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
[[0, 8, 7], [3, 5, 2]]])
# dim=0 (the default) and dim=-3: the 2D slices of each input are appended
# one after another -> shape (6, 2, 3).
torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-3)
# Output condensed to one line per 2D slice:
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]],
#         [[6, 7, 4], [2, 1, 9]],
#         [[1, 6, 3], [9, 6, 0]],
#         [[0, 8, 7], [3, 5, 2]]])
# dim=1 and dim=-2: within each slice, the rows of all inputs are appended
# -> shape (2, 6, 3).
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=1)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-2)
# tensor([[[2, 7, 4],
#          [8, 3, 2],
#          [9, 4, 7],
#          [1, 0, 5],
#          [1, 6, 3],
#          [9, 6, 0]],
#         [[5, 0, 8],
#          [3, 6, 1],
#          [6, 7, 4],
#          [2, 1, 9],
#          [0, 8, 7],
#          [3, 5, 2]]])
# dim=2 and dim=-1: each row is extended with the matching rows of the other
# inputs -> shape (2, 2, 9).
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=2)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([[[2, 7, 4, 9, 4, 7, 1, 6, 3],
#          [8, 3, 2, 1, 0, 5, 9, 6, 0]],
#         [[5, 0, 8, 6, 7, 4, 0, 8, 7],
#          [3, 6, 1, 2, 1, 9, 3, 5, 2]]])
# Float example: same values as the 3D integer example, written as float
# literals; each tensor has shape (2, 2, 3).
tensor1 = torch.tensor([[[2., 7., 4.], [8., 3., 2.]],
[[5., 0., 8.], [3., 6., 1.]]])
tensor2 = torch.tensor([[[9., 4., 7.], [1., 0., 5.]],
[[6., 7., 4.], [2., 1., 9.]]])
tensor3 = torch.tensor([[[1., 6., 3.], [9., 6., 0.]],
[[0., 8., 7.], [3., 5., 2.]]])
# Default dim=0 -> shape (6, 2, 3); output condensed to one line per 2D slice.
torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[2., 7., 4.], [8., 3., 2.]],
#         [[5., 0., 8.], [3., 6., 1.]],
#         [[9., 4., 7.], [1., 0., 5.]],
#         [[6., 7., 4.], [2., 1., 9.]],
#         [[1., 6., 3.], [9., 6., 0.]],
#         [[0., 8., 7.], [3., 5., 2.]]])
# Complex example: same values as the float example, written as complex
# literals; each tensor has shape (2, 2, 3).
tensor1 = torch.tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
[8.+0.j, 3.+0.j, 2.+0.j]],
[[5.+0.j, 0.+0.j, 8.+0.j],
[3.+0.j, 6.+0.j, 1.+0.j]]])
tensor2 = torch.tensor([[[9.+0.j, 4.+0.j, 7.+0.j],
[1.+0.j, 0.+0.j, 5.+0.j]],
[[6.+0.j, 7.+0.j, 4.+0.j],
[2.+0.j, 1.+0.j, 9.+0.j]]])
tensor3 = torch.tensor([[[1.+0.j, 6.+0.j, 3.+0.j],
[9.+0.j, 6.+0.j, 0.+0.j]],
[[0.+0.j, 8.+0.j, 7.+0.j],
[3.+0.j, 5.+0.j, 2.+0.j]]])
# Default dim=0 -> shape (6, 2, 3).
torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
#          [8.+0.j, 3.+0.j, 2.+0.j]],
#         [[5.+0.j, 0.+0.j, 8.+0.j],
#          [3.+0.j, 6.+0.j, 1.+0.j]],
#         [[9.+0.j, 4.+0.j, 7.+0.j],
#          [1.+0.j, 0.+0.j, 5.+0.j]],
#         [[6.+0.j, 7.+0.j, 4.+0.j],
#          [2.+0.j, 1.+0.j, 9.+0.j]],
#         [[1.+0.j, 6.+0.j, 3.+0.j],
#          [9.+0.j, 6.+0.j, 0.+0.j]],
#         [[0.+0.j, 8.+0.j, 7.+0.j],
#          [3.+0.j, 5.+0.j, 2.+0.j]]])
# Bool example: three boolean tensors of shape (2, 2, 3).
tensor1 = torch.tensor([[[True, False, True], [True, False, True]],
[[False, True, False], [False, True, False]]])
tensor2 = torch.tensor([[[False, True, False], [False, True, False]],
[[True, False, True], [True, False, True]]])
tensor3 = torch.tensor([[[True, False, True], [True, False, True]],
[[False, True, False], [False, True, False]]])
# Default dim=0 -> shape (6, 2, 3); output condensed to one line per 2D slice.
torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]],
#         [[False, True, False], [False, True, False]],
#         [[True, False, True], [True, False, True]],
#         [[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]]])
# Empty-tensor example: tensor1 and tensor3 are integer tensors of shape
# (1, 1, 3); tensor2 is torch.tensor([]), a 1D empty tensor of size (0,).
tensor1 = torch.tensor([[[0, 1, 2]]])
tensor2 = torch.tensor([])
tensor3 = torch.tensor([[[0, 1, 2]]])
# torch.cat accepts a 1D empty tensor of size (0,) even though its shape differs
# from the others; it contributes no elements, so the result has shape (2, 1, 3).
# torch.tensor([]) uses the default floating dtype, and type promotion then
# makes the whole result floating point (note the `0.` values below).
torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[0., 1., 2.]],
#         [[0., 1., 2.]]])
This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito) | Sciencx (2024-07-14T03:11:39+00:00) cat() in PyTorch. Retrieved from https://www.scien.cx/2024/07/14/cat-in-pytorch/
Please log in to upload a file.
There are no updates yet.
Click the Upload button above to add an update.