# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Common utilities for torch-native-parallelism examples.
"""

import time
from contextlib import nullcontext

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator


def get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Accelerator | None = None) -> Dataset:
    """
    Build a shuffled, sequence-packed version of the first half of the TinyStories train split.

    Stories are tokenized (truncated to `seq_len` tokens each), concatenated per `map` batch and cut into
    windows of `seq_len + 1` tokens. Each window yields `seq_len` inputs and the same window shifted by one
    token as targets. Tokens left over at the end of a `map` batch that do not fill a whole window are dropped.

    Args:
        tokenizer (AutoTokenizer): Hugging Face tokenizer used to encode the stories
        seq_len (int): Length of every packed sequence
        accelerator (Accelerator, *optional*): If given, both `map` calls run inside
            `accelerator.main_process_first()`; otherwise a no-op context is used

    Returns:
        Dataset: Packed dataset with `input_ids`, `shift_labels`, `position_ids` and `labels` columns,
            shuffled with seed 42
    """
    main_first = accelerator.main_process_first if accelerator else nullcontext
    stories = load_dataset("roneneldan/TinyStories", split="train[:50%]")

    # One extra token per window so that inputs and shifted targets both have `seq_len` entries.
    window = seq_len + 1

    def _encode(batch):
        encoded = tokenizer(
            batch["text"],
            padding=False,
            truncation=True,
            max_length=seq_len,
            return_tensors=None,
        )
        encoded["labels"] = encoded["input_ids"].copy()
        return encoded

    def _pack(batch):
        # Flatten every tokenized story of this batch into one long token stream.
        stream = [token for ids in batch["input_ids"] for token in ids]
        window_count = len(stream) // window

        inputs, targets, positions = [], [], []
        for start in (k * window for k in range(window_count)):
            chunk = stream[start : start + window]
            inputs.append(chunk[:-1])
            targets.append(chunk[1:])
            positions.append(torch.arange(0, seq_len))

        # `labels` and `shift_labels` deliberately carry the same, already shifted, targets.
        return {
            "input_ids": inputs,
            "shift_labels": targets,
            "position_ids": positions,
            "labels": targets,
        }

    with main_first():
        tokenized = stories.map(_encode, batched=True, remove_columns=["text"])

    with main_first():
        packed = tokenized.map(
            _pack,
            batched=True,
            remove_columns=tokenized.column_names,
            batch_size=1000,
        )

    return packed.shuffle(seed=42)


def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
    """
    Estimate the number of flops spent per token by the decoder layers of the model.

    Only the MLP matmuls, the attention projections and the attention dot-products are counted; embeddings,
    layernorms, etc. are ignored. The constants presumably follow the usual accounting of 6 flops per weight
    per token (2 for the forward multiply-add, 4 for the backward pass).

    Args:
        model (AutoModelForCausalLM): Model to get the flops for. Its `config` must expose `hidden_size`,
            `intermediate_size`, `num_attention_heads` and `num_hidden_layers`; `num_key_value_heads` and
            `head_dim` are used when present
        seq_len (int): Sequence length

    Returns:
        float: Estimated flops per token, summed over all hidden layers
    """
    cfg = model.config
    hidden_size = cfg.hidden_size
    num_heads = cfg.num_attention_heads
    # Configs without grouped-query attention may not define (or may leave as None) `num_key_value_heads`.
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    # Some configs set `head_dim` explicitly to a value different from hidden_size // num_heads.
    head_dim = getattr(cfg, "head_dim", None) or hidden_size // num_heads

    # MLP: 3 matmuls
    mlp_flops = 18 * hidden_size * cfg.intermediate_size

    # Attn (w/o dotproduct): Q and O projections are hidden_size x (num_heads * head_dim),
    # K and V projections are hidden_size x (num_kv_heads * head_dim)
    attn_flops = 12 * hidden_size * head_dim * (num_heads + num_kv_heads)

    # attn (dotproduct) - this scales quadratically with sequence length
    attn_dotproduct_flops = 12 * num_heads * head_dim * seq_len

    # we also ignore embeddings and layernorms, etc
    return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers


def create_collate_fn():
    """Return a function that turns a list of dataset items into a dict of batched `torch.long` tensors."""

    def collate_fn(batch):
        def as_long_tensor(column):
            # Gather one column across all items and stack it into a single tensor.
            rows = [sample[column] for sample in batch]
            return torch.tensor(rows, dtype=torch.long)

        token_ids = as_long_tensor("input_ids")
        targets = as_long_tensor("shift_labels")
        # `labels` is the very same tensor object as `shift_labels`, not a copy.
        return {"input_ids": token_ids, "shift_labels": targets, "labels": targets}

    return collate_fn


class PerformanceTracker:
    """Track training performance metrics."""

    def __init__(self, warmup_steps: int = 10):
        self.warmup_steps = warmup_steps
        self.reset()

    def reset(self):
        """Reset all tracking variables."""
        self.start_time = None
        self.num_tokens = 0
        self.is_in_warmup = True
        self.step_count = 0

    def step(self, batch_tokens: int, model_flops_per_token: float | None = None) -> dict:
        """
        Update performance tracking with a new step.

        Args:
            batch_tokens (int): Number of tokens in current batch

        Returns:
            dict: Performance metrics if past warmup, empty dict otherwise
        """
        self.step_count += 1

        if self.step_count == self.warmup_steps:
            self.start_time = time.perf_counter()
            self.num_tokens = 0
            self.is_in_warmup = False
            return {"warmup_completed": True}

        if not self.is_in_warmup and self.start_time is not None:
            dct = {}
            self.num_tokens += batch_tokens
            total_time = time.perf_counter() - self.start_time
            steps_from_warmup = self.step_count - self.warmup_steps

            if total_time > 0 and steps_from_warmup > 0:
                memory_stats = gpu_memory_usage_all()
                dct = {
                    "tokens_per_second": self.num_tokens / total_time,
                    "steps_per_second": steps_from_warmup / total_time,
                    "total_tokens": self.num_tokens,
                    "total_time": total_time,
                    **memory_stats,
                }

            if model_flops_per_token is not None:
                flops = model_flops_per_token * self.num_tokens
                dct["tflops_per_device"] = flops / (total_time * 1e12)

            return dct

        return {}

    def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
        print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f} | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
        if with_memory:
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        return print_msg


def setup_tokenizer(model_id: str) -> AutoTokenizer:
    """Load the tokenizer for `model_id`, reusing the EOS token for padding when no pad token is set."""
    tok = AutoTokenizer.from_pretrained(model_id)
    has_pad_token = tok.pad_token is not None
    if not has_pad_token:
        # No pad token defined: fall back to the EOS token.
        tok.pad_token = tok.eos_token
    return tok


def gpu_memory_usage_all(device=0):
    device_type = torch.accelerator.current_accelerator().type
    device = torch.device(f"{device_type}:{device}")
    torch_device_module = getattr(torch, device_type, torch.cuda)
    _BYTES_IN_GIB = 1024**3
    peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
    memory_stats = {
        "peak_memory_active": peak_memory_active,
        "peak_memory_alloc": peak_memory_alloc,
        "peak_memory_reserved": peak_memory_reserved,
    }
    torch_device_module.reset_peak_memory_stats(device)

    return memory_stats
