StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType > Struct Template Reference
Stream-K tile partitioner base class.
#include <streamk_gemm_tile_partitioner.hpp>
Public Types | |
| using | BlockGemmShape = BlockGemmShapeType |
Public Member Functions | |
| StreamKTilePartitionerBase (index_t m, index_t n, index_t k, index_t grid) | |
| CK_TILE_HOST_DEVICE index_t | get_partials_buffer_size (index_t acc_element_bytes) const noexcept |
| Calculates the total space needed for the partials buffer. | |
| CK_TILE_HOST_DEVICE index_t | get_flags_buffer_size () const noexcept |
| Calculates the total space needed for the flags buffer. | |
| CK_TILE_DEVICE void | get_iter_boundaries (index_t &iter_start, index_t &iter_end, index_t cta_idx) const noexcept |
| Calculates the start and end iteration given the cta_idx. | |
| CK_TILE_DEVICE index_t | get_tile_index (index_t iter_start) const noexcept |
| Calculates the 1D tile index in the C tensor for a workgroup. | |
| CK_TILE_DEVICE void | get_tile_boundaries (index_t &tile_iter_start, index_t &tile_iter_end, index_t tile_idx) const noexcept |
| Calculates the starting and ending tile boundaries for the given 1D tile index. | |
| CK_TILE_DEVICE auto | get_output_tile_index (index_t tile_idx) const noexcept -> tuple< index_t, index_t > |
| Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index. | |
| CK_TILE_HOST_DEVICE index_t | get_workspace_size (index_t acc_element_bytes) const noexcept |
| Calculates the total space needed for the partials and flags buffers. | |
| CK_TILE_HOST_DEVICE index_t | get_num_tiles () const noexcept |
| Returns the number of macro tiles in the C tensor. | |
| CK_TILE_HOST_DEVICE index_t | get_grid () const noexcept |
| Returns the maximum number of active workgroups; this is assumed to be the number of CUs * occupancy. | |
| CK_TILE_HOST_DEVICE index_t | get_dp_tiles () const noexcept |
| Returns the number of tiles in the C tensor that will use the data-parallel (DP) approach. | |
| CK_TILE_HOST_DEVICE index_t | get_sk_tiles () const noexcept |
| Returns the number of tiles in the C tensor that will use the Stream-K approach. | |
| CK_TILE_HOST_DEVICE index_t | get_sk_ctas () const noexcept |
| Returns the number of workgroups that will participate in Stream-K in the sk_tiles_. | |
| CK_TILE_HOST_DEVICE index_t | get_total_sk_iters () const noexcept |
| Returns the total number of Stream-K iterations. | |
| CK_TILE_HOST_DEVICE index_t | get_iters_per_tile () const noexcept |
| Returns the total number of iterations per tile in the C tensor. In other words, this is the total number of macro tiles along the K dimension of A and B. | |
| CK_TILE_HOST_DEVICE index_t | get_iters_per_sk_cta () const noexcept |
| Returns the total number of Stream-K iterations for each sk_cta. This is the lower bound (i.e., all sk_ctas_ are guaranteed to perform at least this many iterations). | |
| CK_TILE_HOST_DEVICE index_t | get_extra_iters () const noexcept |
| Returns the remainder resulting from total_sk_iters_ divided by sk_ctas_. When this is non-zero, the first extra_iters_ sk_ctas_ will get one additional iteration assigned to them; such workgroups will perform (iters_per_sk_cta_ + 1) iterations. | |
| CK_TILE_HOST_DEVICE index_t | get_total_dp_iters () const noexcept |
| Returns the total number of DP iterations. | |
| CK_TILE_HOST_DEVICE index_t | get_n () const noexcept |
| Returns the n dimension for the GEMM problem. | |
| CK_TILE_HOST index_t | estimate_num_wgs_per_tile () const noexcept |
| Returns an estimate of the number of workgroups writing to the same macro tile in C. | |
Static Public Member Functions | |
| static CK_TILE_DEVICE index_t | get_local_iter (index_t iter_start, index_t tile_iter_start) noexcept |
| Calculates the workgroup's starting iteration that is local to a tile. | |
| static CK_TILE_DEVICE index_t | get_local_iter_end (index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept |
| Calculates the workgroup's non-inclusive end iteration that is local to a tile. | |
Static Public Attributes | |
| static constexpr index_t | MPerBlock = BlockGemmShape::kM |
| static constexpr index_t | NPerBlock = BlockGemmShape::kN |
| static constexpr index_t | KPerBlock = BlockGemmShape::kK |
| static constexpr StreamKReductionStrategy | ReductionStrategy = ReductionStrategyType |
Protected Attributes | |
| index_t | num_tiles_ |
| index_t | grid_ |
| index_t | dp_tiles_ |
Detailed Description
struct ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >
Stream-K tile partitioner base class.
This partitioner is responsible for mapping workgroups to tiles in the C tensor for the Stream-K algorithm.
- Template Parameters
- BlockGemmShapeType A class providing basic GEMM parameters.
- ReductionStrategyType An enum that defines the reduction strategy for the results in the C tensor.
Member Typedef Documentation
◆ BlockGemmShape
| using ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::BlockGemmShape = BlockGemmShapeType |
Constructor & Destructor Documentation
◆ StreamKTilePartitionerBase()
| ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::StreamKTilePartitionerBase | ( | index_t | m, |
| index_t | n, | ||
| index_t | k, | ||
| index_t | grid ) |
Member Function Documentation
◆ estimate_num_wgs_per_tile()
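| CK_TILE_HOST index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::estimate_num_wgs_per_tile () const | noexcept |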
Returns an estimate of the number of workgroups writing to the same macro tile in C.
◆ get_dp_tiles()
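| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_dp_tiles () const | noexcept |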
Returns the number of tiles in the C tensor that will use the data-parallel (DP) approach.
◆ get_extra_iters()
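| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_extra_iters () const | noexcept |
Returns the remainder resulting from total_sk_iters_ divided by sk_ctas_. When this is non-zero, the first extra_iters_ sk_ctas_ will get one additional iteration assigned to them; such workgroups will perform (iters_per_sk_cta_ + 1) iterations.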
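The split described here is integer division with remainder. The following is a minimal standalone sketch of that arithmetic, not the ck_tile implementation: `SkSplit`, `split_sk_iters`, and `iters_for_cta` are hypothetical names, and plain `int` stands in for `index_t`.

```cpp
#include <cassert>

// Hypothetical stand-in for the iters_per_sk_cta_ / extra_iters_ pair.
struct SkSplit
{
    int iters_per_sk_cta; // lower bound on iterations per Stream-K workgroup
    int extra_iters;      // remainder; the first extra_iters workgroups get one more
};

// Divide total_sk_iters across sk_ctas Stream-K workgroups.
inline SkSplit split_sk_iters(int total_sk_iters, int sk_ctas)
{
    assert(total_sk_iters >= 0 && sk_ctas > 0);
    return SkSplit{total_sk_iters / sk_ctas, total_sk_iters % sk_ctas};
}

// Number of iterations performed by the Stream-K workgroup with index cta_idx.
inline int iters_for_cta(const SkSplit& s, int cta_idx)
{
    return s.iters_per_sk_cta + (cta_idx < s.extra_iters ? 1 : 0);
}
```

For example, 10 Stream-K iterations over 4 workgroups give a lower bound of 2 and a remainder of 2, so workgroups 0 and 1 perform 3 iterations each and workgroups 2 and 3 perform 2 each.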
◆ get_flags_buffer_size()
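| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_flags_buffer_size () const | noexcept |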
Calculates the total space needed for the flags buffer.
- Returns
- index_t The number of bytes needed for the flags buffer.
◆ get_grid()
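| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_grid () const | noexcept |
Returns the maximum number of active workgroups; this is assumed to be the number of CUs * occupancy.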
◆ get_iter_boundaries()
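| CK_TILE_DEVICE void ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_iter_boundaries (index_t &iter_start, index_t &iter_end, index_t cta_idx) const | noexcept |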
Calculates the start and end iteration given the cta_idx.
- Parameters
- iter_start Reference to an index_t; will be set to the starting iteration by the function.
- iter_end Reference to an index_t; will be set to the non-inclusive end iteration by the function.
- cta_idx The current Stream-K workgroup's index.
- Note
- It is assumed that the first Stream-K workgroup has a cta_idx of zero. If a non-persistent DP section is used, a Stream-K workgroup's cta_idx should be offset accordingly, for example blockIdx.x minus the number of DP workgroups.
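One way to picture the boundaries is as contiguous half-open ranges, with the first extra_iters_ workgroups one iteration longer than the rest. The sketch below is standalone and makes that layout an explicit assumption; it is not taken from the ck_tile source. `iter_boundaries_sketch` is a hypothetical name and `int` stands in for `index_t`.

```cpp
#include <algorithm>
#include <cassert>

// Sketch only: gives each Stream-K workgroup a contiguous, half-open
// iteration range [iter_start, iter_end). Assumes ranges are laid out in
// increasing cta_idx order and that the first extra_iters workgroups each
// take one additional iteration.
inline void iter_boundaries_sketch(int& iter_start,
                                   int& iter_end,
                                   int cta_idx,
                                   int iters_per_sk_cta,
                                   int extra_iters)
{
    assert(cta_idx >= 0 && iters_per_sk_cta >= 0 && extra_iters >= 0);
    // Every earlier workgroup contributed iters_per_sk_cta iterations, and
    // min(cta_idx, extra_iters) of them contributed one extra.
    iter_start = cta_idx * iters_per_sk_cta + std::min(cta_idx, extra_iters);
    iter_end   = iter_start + iters_per_sk_cta + (cta_idx < extra_iters ? 1 : 0);
}
```

With iters_per_sk_cta = 2 and extra_iters = 2, workgroups 0 through 3 receive [0, 3), [3, 6), [6, 8), and [8, 10).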
◆ get_iters_per_sk_cta()
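| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_iters_per_sk_cta () const | noexcept |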
Returns the total number of Stream-K iterations for each sk_cta. This is the lower bound (i.e., all sk_ctas_ are guaranteed to perform at least this many iterations).
◆ get_iters_per_tile()
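| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_iters_per_tile () const | noexcept |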
Returns the total number of iterations per tile in the C tensor. In other words, this is the total number of macro tiles along the K dimension of A and B.
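Counting macro tiles along K amounts to a ceiling division of k by KPerBlock. The standalone sketch below assumes a partial block at the K edge still counts as a full iteration; `ceil_div` and `iters_per_tile_sketch` are hypothetical names, and `int` stands in for `index_t`.

```cpp
#include <cassert>

// Ceiling division for non-negative a and positive b.
inline int ceil_div(int a, int b)
{
    assert(a >= 0 && b > 0);
    return (a + b - 1) / b;
}

// Sketch: number of K-dimension macro tiles, i.e. iterations per C tile.
// k_per_block plays the role of KPerBlock.
inline int iters_per_tile_sketch(int k, int k_per_block)
{
    return ceil_div(k, k_per_block);
}
```

For k = 1000 and a K block of 64, this yields 16 iterations per tile: 15 full blocks plus one partial block.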
◆ get_local_iter()
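| static CK_TILE_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_local_iter (index_t iter_start, index_t tile_iter_start) | static noexcept |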
Calculates the workgroup's starting iteration that is local to a tile.
- Parameters
- iter_start The starting iteration.
- tile_iter_start The starting iteration of the tile (i.e., the tile's starting boundary).
- Returns
- index_t The local starting iteration. The value is in range [0, iters_per_tile_).
- Note
- Assumes iter_start >= tile_iter_start.
◆ get_local_iter_end()
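| static CK_TILE_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_local_iter_end (index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) | static noexcept |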
Calculates the workgroup's non-inclusive end iteration that is local to a tile.
- Parameters
- tile_iter_start The starting tile iteration.
- iter_end The non-inclusive end iteration.
- tile_iter_end The non-inclusive end iteration of the tile.
- Returns
- index_t The local non-inclusive end iteration.
- Note
- Assumes iter_end >= tile_iter_start and tile_iter_end >= tile_iter_start.
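The two tile-local helpers can be read as translating a workgroup's global iteration range into coordinates relative to the start of one tile. The sketch below is one plausible reading consistent with the stated preconditions, not a copy of the ck_tile code: it assumes the end is clipped to the tile boundary. `local_iter_sketch` and `local_iter_end_sketch` are hypothetical names, and `int` stands in for `index_t`.

```cpp
#include <algorithm>
#include <cassert>

// Sketch: the workgroup's first iteration, relative to the tile's start.
inline int local_iter_sketch(int iter_start, int tile_iter_start)
{
    assert(iter_start >= tile_iter_start);
    return iter_start - tile_iter_start;
}

// Sketch: the workgroup's non-inclusive end iteration, relative to the
// tile's start. Assumes the range is clipped so it never passes the tile end.
inline int local_iter_end_sketch(int tile_iter_start, int iter_end, int tile_iter_end)
{
    assert(iter_end >= tile_iter_start && tile_iter_end >= tile_iter_start);
    return std::min(iter_end, tile_iter_end) - tile_iter_start;
}
```

If a tile spans global iterations [8, 12) and a workgroup owns [10, 15), the workgroup's share of that tile is the local range [2, 4); the rest of its range belongs to the next tile.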
◆ get_n()
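| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_n () const | noexcept |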
Returns the n dimension for the GEMM problem.
◆ get_num_tiles()
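| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_num_tiles () const | noexcept |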
Returns the number of macro tiles in the C tensor.
◆ get_output_tile_index()
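| CK_TILE_DEVICE auto ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_output_tile_index (index_t tile_idx) const -> tuple< index_t, index_t > | noexcept |
Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.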
- Parameters
- tile_idx The 1D tile index in the C tensor for the workgroup.
- Returns
- tuple< index_t, index_t > The corresponding 2D tile index in the C tensor for the workgroup.
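A common way to turn a 1D tile index into a 2D one is a row-major split over the tile grid. The sketch below assumes that ordering, with the number of tile columns derived from n and NPerBlock; the actual ordering used by ck_tile is not stated on this page. `output_tile_index_sketch` is a hypothetical name, `std::pair` stands in for `ck_tile::tuple`, and `int` for `index_t`.

```cpp
#include <cassert>
#include <utility>

// Sketch: map a 1D tile index to (row, column) in the C tile grid.
// Assumes row-major tile ordering; n_per_block plays the role of NPerBlock.
inline std::pair<int, int> output_tile_index_sketch(int tile_idx, int n, int n_per_block)
{
    assert(tile_idx >= 0 && n > 0 && n_per_block > 0);
    // Number of macro tiles along N; a partial tile at the edge still counts.
    const int n_tiles = (n + n_per_block - 1) / n_per_block;
    return {tile_idx / n_tiles, tile_idx % n_tiles};
}
```

With n = 512 and an N block of 128 there are 4 tile columns, so 1D tile index 6 maps to row 1, column 2.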
◆ get_partials_buffer_size()
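| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_partials_buffer_size (index_t acc_element_bytes) const | noexcept |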
Calculates the total space needed for the partials buffer.
- Parameters
- acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
- Returns
- index_t The number of bytes needed for the partials buffer.
◆ get_sk_ctas()
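| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_sk_ctas () const | noexcept |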
Returns the number of workgroups that will participate in Stream-K in the sk_tiles_.
◆ get_sk_tiles()
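| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_sk_tiles () const | noexcept |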
Returns the number of tiles in the C tensor that will use the Stream-K approach.
◆ get_tile_boundaries()
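| CK_TILE_DEVICE void ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_tile_boundaries (index_t &tile_iter_start, index_t &tile_iter_end, index_t tile_idx) const | noexcept |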
Calculates the starting and ending tile boundaries for the given 1D tile index.
◆ get_tile_index()
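| CK_TILE_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_tile_index (index_t iter_start) const | noexcept |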
Calculates the 1D tile index in the C tensor for a workgroup.
- Parameters
- iter_start The starting iteration.
- Returns
- index_t The 1D tile index.
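If every tile owns iters_per_tile_ consecutive iterations, the tile index and the tile's iteration boundaries follow from a division and a multiplication. The standalone sketch below assumes iteration numbering starts at zero on the first tile and ignores any offset a separate DP section might introduce; `tile_index_sketch` and `tile_boundaries_sketch` are hypothetical names, and `int` stands in for `index_t`.

```cpp
#include <cassert>

// Sketch: 1D tile index that contains the global iteration iter_start.
inline int tile_index_sketch(int iter_start, int iters_per_tile)
{
    assert(iter_start >= 0 && iters_per_tile > 0);
    return iter_start / iters_per_tile;
}

// Sketch: half-open global iteration range [tile_iter_start, tile_iter_end)
// owned by the tile with index tile_idx.
inline void tile_boundaries_sketch(int& tile_iter_start,
                                   int& tile_iter_end,
                                   int tile_idx,
                                   int iters_per_tile)
{
    assert(tile_idx >= 0 && iters_per_tile > 0);
    tile_iter_start = tile_idx * iters_per_tile;
    tile_iter_end   = tile_iter_start + iters_per_tile;
}
```

With 4 iterations per tile, a workgroup starting at global iteration 10 begins in tile 2, whose boundaries are [8, 12).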
◆ get_total_dp_iters()
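| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_total_dp_iters () const | noexcept |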
Returns the total number of DP iterations.
◆ get_total_sk_iters()
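| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_total_sk_iters () const | noexcept |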
Returns the total number of Stream-K iterations.
◆ get_workspace_size()
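| CK_TILE_HOST_DEVICE index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::get_workspace_size (index_t acc_element_bytes) const | noexcept |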
Calculates the total space needed for the partials and flags buffers.
- Parameters
- acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
- Returns
- index_t The number of bytes needed for the partials and flags buffers.
Member Data Documentation
◆ dp_tiles_
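| index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::dp_tiles_ | protected |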
◆ grid_
| index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::grid_ | protected |
◆ KPerBlock
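| constexpr index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::KPerBlock = BlockGemmShape::kK | static constexpr |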
◆ MPerBlock
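| constexpr index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::MPerBlock = BlockGemmShape::kM | static constexpr |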
◆ NPerBlock
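| constexpr index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::NPerBlock = BlockGemmShape::kN | static constexpr |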
◆ num_tiles_
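| index_t ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::num_tiles_ | protected |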
◆ ReductionStrategy
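| constexpr StreamKReductionStrategy ck_tile::StreamKTilePartitionerBase< BlockGemmShapeType, ReductionStrategyType >::ReductionStrategy = ReductionStrategyType | static constexpr |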
The documentation for this struct was generated from the following files: