gemm_aquant_pipeline_ag_bg_cr_base.hpp Source File

gemm_aquant_pipeline_ag_bg_cr_base.hpp Source File

Composable Kernel: gemm_aquant_pipeline_ag_bg_cr_base.hpp Source File
gemm_aquant_pipeline_ag_bg_cr_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12template <typename Problem, typename Policy>
14{
16 using ADataType = typename Base::ADataType;
17 using ALayout = typename Base::ALayout;
18 using BDataType = typename Base::BDataType;
19 using BLayout = typename Base::BLayout;
22
24
25 static constexpr index_t MPerBlock = BlockGemmShape::kM;
26 static constexpr index_t NPerBlock = BlockGemmShape::kN;
27 static constexpr index_t KPerBlock = BlockGemmShape::kK;
28
29 static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
30
31 static_assert(KPerBlock % QuantGroupSize::kK == 0,
32 "KPerBlock must be a multiple of QuantGroupSize");
33
34 // Create DRAM tile window for AQ
35 template <typename AQDramBlockWindowTmp>
36 CK_TILE_DEVICE constexpr auto
37 GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
38 {
39 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
40
41 auto aq_copy_dram_window =
42 make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
43 aq_dram_block_window_tmp.get_window_lengths(),
44 aq_dram_block_window_tmp.get_window_origin(),
45 Policy::template MakeAQDramTileDistribution<Problem>());
46 return aq_copy_dram_window;
47 }
48};
49
50} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
int32_t index_t
Definition integer.hpp:9
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:14
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:21
remove_cvref_t< typename Problem::AQLayout > AQLayout
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:23
typename Base::ADataType ADataType
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:16
typename Base::BlockGemmShape BlockGemmShape
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:20
static constexpr index_t NPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:26
static constexpr index_t KPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:27
typename Base::ALayout ALayout
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:17
typename Base::BDataType BDataType
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:18
CK_TILE_DEVICE constexpr auto GetAQDramLoadWindow(const AQDramBlockWindowTmp &aq_dram_block_window_tmp) const
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:37
GemmPipelineAgBgCrImplBase< Problem, Policy > Base
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:15
static constexpr index_t KPerBlockAQ
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:29
static constexpr index_t MPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:25
typename Base::BLayout BLayout
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:19
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:23
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:21
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_base.hpp:18
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:20