gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp File Reference
gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp File Reference
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"

Go to the source code of this file.
Namespaces | |
| namespace | ck |
Functions | |
| template<typename GridwiseGemm, typename FloatAB, typename FloatC, typename FloatC0, typename AElementwiseOperation, typename BElementwiseOperation, typename AccElementwiseOperation, typename CElementwiseOperation, typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename C0GridDescriptor_NBlock_NPerBlock, typename Block2CTileMap, bool HasMainKBlockLoop> | |
| __global__ void | ck::kernel_gemm_layernorm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_c0_bias_grid, const FloatC0 *__restrict__ p_c0_add_grid, const FloatC0 *__restrict__ p_c0_gamma_grid, const FloatC0 *__restrict__ p_c0_beta_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map) |