NVIDIA 의 CuTe 는 이해하기 매우 어렵다. 하지만 그의 바탕이 되는 이론은 정교하고 이를 활용한다면 높은 퍼포먼스의 코딩을 쉽게 할 수 있게 도와준다.
CuTe 의 개념에 대한 설명은 https://docs.nvidia.com/cutlass/media/docs/cpp/cute/index.html 에 적혀있으나, 이를 보고 이해하는 것은 매우 어렵다. 본 글은 CuTe 를 처음 접하는 이들에게 CuTe 의 기본 개념을 익힐 수 있는 기회를 제공하고자한다.
Layout
CuTe 의 중요 개념은 Layout 이다. CuTe 의 모든 것은 바로 Layout 에서 시작한다고 보면 된다. Layout 은 Tensor 의 data 를 어떻게 저장하고 다룰 것인지 정의하는 개념이자 도구이다.
Tensor 가 1의 stride 로 10개의 data point [1, 2, 3, …, 10] 을 가지고 있다면 이 tensor 를 다루기 위한 Layout 은 shape = 10, stride = 1 일 것이다.
Tensor 는 1차원으로 제한되지 않는다. 2차원 tensor, 3차원 tensor, 100 차원 tensor 도 존재할 수 있다. 즉 Layout 또한 그러한 Tensor 들을 다루기 위해 여러 차원을 표시할 수 있어야한다.
3x4 shape 의 2차원 tensor (matrix)를 생각해보자.
1, 2, 3, 4
5, 6, 7, 8
9, 10, 11, 12tensor 의 data 를 담고 있는 시작 주소가 0x00 이라 한다면 각 element 를 담고있는 주소는 어디일까?
data type 이 float 이라면 높은 확률로 아래와 같은 주소를 가질 것이라 예상할 수 있다.
1(0x00), 2(0x04), 3(0x08), 4(0x0C)
5(0x10), 6(0x14), 7(0x18), 8(0x1C)
9(0x20), 10(0x24), 11(0x28), 12(0x2C)하지만 항상 위와 같은 주소를 가지는 것은 아니다. pytorch 의 예시를 들어보자.
아래 코드는 위와 똑같은 값을 가지는 matrix 를 출력한다. 하지만 각 element 의 주소는 이전과는 다른 주소를 가진다. 각 element 의 주소를 출력과 같이 표기해보았다.
>>> arr = torch.tensor([
[1, 2, 3, 4, 0,],
[5, 6, 7, 8, 0,],
[9, 10, 11, 12, 0,],
], dtype=torch.float)[:, :-1]
>>> print_tensor_with_address(arr)
1(0x00), 2(0x04), 3(0x08), 4(0x0C)
5(0x14), 6(0x18), 7(0x1C), 8(0x20)
9(0x28), 10(0x2C), 11(0x30), 12(0x34)위의 예시에서 알 수 있듯 우리는 시작 주소를 안다한들 matrix 의 값만 봐서는 각 element 의 주소를 알 수 없다.
이를 해결하기 위한 개념이 바로 Stride 다. Stride 는 각 dim 의 값들이 일정 크기의 간격을 갖고있다는 가정에서 출발한다.
바로 위의 예시에서 각 column (dim=1) 의 이웃한 값들은 0x04 만큼의 간격을 갖고 있다. 각 row (dim=0) 의 이웃한 값들은 0x14 의 간격을 갖고 있다. 이는 아래의 코드로도 확인할 수 있다.
>>> print(f"row stride = 0x{arr.stride(0) * 4:02X}, "
f"col stride = 0x{arr.stride(1) * 4:02X}")
row stride = 0x14, col stride = 0x04Layout 역시 shape 와 stride 를 통해 tensor 을 규정한다.
Layout 은 Shape:Stride 의 형태로 나타내진다. 다음은 Layout 의 예시들이다.
- 4:1
- 4:8
- (3, 2):(1, 3) - shape = (3, 2), stride = (1, 3)
- (3, 2):(2, 1) - shape = (3, 2), stride = (2, 1)
- (2, (2, 2)):(4, (2, 1)) - shape = (2, (2, 2)), stride = (4, (2, 1))
각 숫자들을 mode 라고 부른다. (3, 2) 의 first mode 는 3, second mode 는 2.
사실 Layout 은 shape space 의 tuple 을 받아서 index space 의 integer 출력하는 하나의 함수이다.
가령 Layout 4:8 에 2 을 input 으로 넣으면 16 이 output 로 나온다. 즉 Layout 은 input 으로 넣은 coordinate 에 위치한 값이 저장된 주소가 어딘지 출력해주는 함수이다.
출력되는 값을 계산하는 방법은 아주 간단한데, 각 coord 와 stride 을 곱한 뒤 합해준 값을 계산하면 된다.
Layout (3, 2):(2, 1) 에 coordinate (2, 0) 을 input 으로 넣으면 2 * 2 + 1 * 0 = 4 가 출력된다.
Layout (3, 2):(2, 1) 의 각 coordinate 별 index 값을 출력하면 다음과 같다.
Layout (3, 2):(2, 1) 의 coordinate 는 2 차원으로 표시할 수 있지만, 1차원으로 나타낼 수도 있다.
template <class Shape, class Stride>
void print1D(Layout<Shape,Stride> const& layout)
{
for (int m = 0; m < size(layout); ++m) {
printf("%3d ", layout(m));
}
printf("\n");
}
output:
0 2 4 1 3 5 조금더 복잡한 Layout 에 대해 살펴보자.
(2,(2,2)):(4,(2,1)) 의 coordinate 은 기본적으로 3 차원이지만 1차원, 2차원으로도 접근 가능하다.
사실 수학적으론 (2, (2, 2)) 와 같이 괄호를 중첩해서 쓰는 것은 의미가 없고 (2, 2, 2) 로만 표현해도 충분하다.
Coalesce
Coalesce 는 Layout 을 단순화 하는 것이다.
Layout (2,(3,1)):(1,(2,6)) 을 예로 들어보자. print1D(layout) 의 결과는 다음과 같을 것이다.
auto layout = make_layout(
make_shape(Int<2>{}, make_shape(Int<3>{}, Int<1>{})),
make_stride(Int<1>{}, make_stride(Int<2>{}, Int<6>{}))
);
print(layout);
print("\n------------\n");
print1D(layout);
output:
0 1 2 3 4 5 여기서 우리는 Layout (2,(3,1)):(1,(2,6)) 이 사실상 Layout 6:1 과 같다는 사실을 알 수 있다.
마찬가지로 Layout (2, 4):(2, 4) 는 Layout 8:2 와 같다.
그럼 어떤 Layout 이 주어졌을 때 그 layout 과 사실상 같은 1D layout 이 존재할지 안할지 어떻게 알 수 있을까?
이를 알아보기 위해 Layout 에 대해 아래의 4가지 경우로 나눠보자.
- = 1
- = 1
- 그외의 경우
- = 1 인 경우 자명하게 와 같다
- = 1 인 경우 과 같다.
- 의 경우 과 같다.
- 단순화 할 수 없다.
이 과정을 recursive 하게 반복하면 layout 의 단순화 된 표현이 있는지 없는지 알 수 있다.
(2, (3, 1)):(1, (2, 6)) → (2, 3):(1, 2) → (6, 1)
(2, 4):(2, 4) → (8, 2)
(4, 2):(4, 2) 는 coalesce 되지 않는다. 여기서 우리는 stride 에 순서라는 개념이 있다는 것을 엿볼수 있다
Composition
composition 은 Layout 의 곱을 구하는 것이다.
CuTe 에선 아래의 함수로 composition 을 쉽게 구할 수 있다.
Layout composition(LayoutA const& layout_a, LayoutB const& layout_b)coalesce 에서 쉽게 판단하는 방법이 있었듯 composition 을 쉽게 계산하는 방법 역시 존재한다.
예시를 들어서 설명하겠다.
B 의 mode 가 2개 이상일 때 규칙을 직관적으로 발견하긴 어렵다. 하지만 B 의 mode 가 1개 일때는 아래와 같은 형태를 보인다.
A = (4, 8):(13, 1)
B = 8:2
(4, 8):(13, 1) / 2 = (2, 8):(13 * 2, 1)
(2, 8):(26, 1) % 8 = (2, 4):(26, 1)
A o B = (2, 4):(26, 1)
A
0 1 2 3 4 5 6 7
13 14 15 16 17 18 19 20
26 27 28 29 30 31 32 33
39 40 41 42 43 44 45 46
B
0
2
4
6
8
10
12
14
A o B
0 1 2 3
26 27 28 29A 를 2개 element 간격으로 추려낸 것이 A o B 임을 알 수 있다.
By-mode composition
By-mode composition 은 layer 와 layer 을 composition 하는 것과는 달리
각 layer 의 mode 끼리 composition 을 수행하는 것이다.
앞서 살펴봤듯 mode 끼리 composition 을 수행하는 것은 일종의 indexing 효과를 지닌다. (표현이 애매하다)
CuTe 에서 By-mode composition 은 tile 을 이용해서 구현할 수 있다.
그림에서 볼 수 있듯 composition 은 Layer 의 일부 Tile 을 선택하는 것으로 볼 수 있다.
Complement
complement 는 Layout A 와 A 보다 큰 Layout B 가 주어졌을 때 Layout A 를 어떻게 반복해서 Layout B 를 cover 할 수 있을지에 관한 함수이다.
0, 6
1, 7위의 layer 를 (3, 2):(2, 12) 로 layering 하면 24 가 되는 것을 visual 로 확인 할 수 있다. (submatrix 같은 개념을 떠올리면 된다.)
Layout 을 그의 complement 와 concatenate 하면 Layout B 와 isomorphic 한 layout 이 나온다.
concatenate 해서 복원한 isomorphic 한 layer 는 다른 indexing 을 가질 수 있다. 이 indexing 은 Layout A 가 반복되는 패턴을 가질 것이다.
Division (Tiling)
template <class LShape, class LStride,
class TShape, class TStride>
auto logical_divide(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tiler)
{
return composition(layout, make_layout(tiler, complement(tiler, size(layout))));
}Consider tiling the 1-D layout A = (4, 2, 3):(2, 1, 8) with the tiler B = 4:2.
- Complement of
B = 4:2undersize(A) = 24isB* = (2,3):(1,8). - Concantenation of
(B,B*) = (4,(2,3)):(2,(1,8)). - Composition of
A = (4,2,3):(2,1,8)with(B,B*)is then((2,2),(2,3)):((4,1),(2,8)).
Division 이 무엇인지 알기위해 차례대로 살펴보자.
먼저 make_layout(tiler, complement(tiler, size(layout))) 부분을 보면 layout 를 tiler 의 형태로 복원한다.
layout 과 복원된 layout 를 composition 한다. 이는 layout 에서 같은 tile 에 있는 것들을 grouping 해주는 역할을 수행한다.
Product (Tiling)
template <class LShape, class LStride,
class TShape, class TStride>
auto logical_product(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tiler)
{
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
}product 는 직관적인 이해가 쉽다.
이들을 종합하면 아래와 같은 연산이 가능하다
아래와 같이 Tensor E 를 잡는다면 어떤 figure 가 그려질지 생각해보자.
Tensor E = A(make_coord(_,1),make_coord(1,_,1));Layout Shape : (M, N, L, ...)
Tiler Shape : <TileM, TileN>
logical_divide : ((TileM,RestM), (TileN,RestN), L, ...)
zipped_divide : ((TileM,TileN), (RestM,RestN,L,...))
tiled_divide : ((TileM,TileN), RestM, RestN, L, ...)
flat_divide : (TileM, TileN, RestM, RestN, L, ...)Inner and outer partitioning
Let’s take a tiled example and look at how we can slice it in useful ways.
아래 코드에서 볼 수 있듯 local_tile (=zipped_divide + coord) 의 과정에서 stride 는 중요하지 않음. 물론 output 에는 stride 가 영향을 줌
참고로 make_identity_tensor 는 아래와 같은 함수
# Create a simple 1D coord tensor
tensor = make_identity_tensor(6) # [0,1,2,3,4,5]
# Create a 2D coord tensor
tensor = make_identity_tensor((3,2)) # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)]
# Create hierarchical coord tensor
tensor = make_identity_tensor(((2,1),3))
# [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)]아래와 같이 local_partition 은 layout 의 stride 가 중요하다. local_parition 은 tid 를 layout 의 역함수를 이용하여 coordinate 로 변환하기 때문.
Suppose that we want to give each threadgroup one of these 4x8 tiles of data. Then we can use our threadgroup coordinate to index into the second mode.
Tensor cta_a = tiled_a(make_coord(_,_), make_coord(blockIdx.x, blockIdx.y)); // (_4,_8)
We call this an inner-partition because it keeps the inner “tile” mode. This pattern of applying a tiler and then slicing out that tile by indexing into the remainder mode is common and has been wrapped into its own function inner_partition(Tensor, Tiler, Coord). You’ll often see local_tile(Tensor, Tiler, Coord) which is just another name for inner_partition. The local_tile partitioner is very often applied at the threadgroup level to partition tensors into tiles across threadgroups.
Alternatively, suppose that we have 32 threads and want to give each thread one element of these 4x8 tiles of data. Then we can use our thread to index into the first mode.
Tensor thr_a = tiled_a(threadIdx.x, make_coord(_,_)); // (2,3)
We call this an outer-partition because it keeps the outer “rest” mode. This pattern of applying a tiler and then slicing into that tile by indexing into the tile mode is common and has been wrapped into its own function outer_partition(Tensor, Tiler, Coord). Sometimes you’ll see local_partition(Tensor, Layout, Idx), which is a rank-sensitive wrapper around outer_partition that transforms the Idx into a Coord using the inverse of the Layout and then constructs a Tiler with the same top-level shape of the Layout. This allows the user to ask for a row-major, column-major, or arbitrary layout of threads with a given shape that can be used to partition into a tensor.
/