package main

import (
	rl "github.com/gen2brain/raylib-go/raylib"
)

// TransitionPhase identifies which sub-phase of a step transition is
// currently playing, as reported by TransitionManager.GetPhase.
type TransitionPhase int

// Transition phases, listed in the order they occur.
const (
	// TransitionNone means no transition is active.
	TransitionNone TransitionPhase = iota
	// TransitionOutFadeScale is the first sub-phase: the old step dims and
	// shrinks in place while the new step is held hidden off-screen.
	TransitionOutFadeScale
	// TransitionOutSlide is the middle sub-phase: old and new steps slide
	// horizontally, the old one out and the new one in.
	TransitionOutSlide
	// TransitionInFadeScale is the last sub-phase: the new step grows back
	// to full alpha and scale at its resting position.
	TransitionInFadeScale
)

// TransitionDirection selects which way steps slide during a transition.
type TransitionDirection int

const (
	// DirectionForward is chosen by Start when the new step number is
	// greater than the old one; Update then slides with a +1 sign, so the
	// new step starts at a positive X offset.
	DirectionForward TransitionDirection = iota
	// DirectionBackward is chosen in every other case (smaller or equal
	// step number); Update mirrors the slide with a -1 sign.
	DirectionBackward
)

// TransitionManager drives a three-part animation between two steps:
// fade/scale out, slide, then fade/scale in. Start begins a transition and
// Update, called once per frame, advances its clock and yields draw values.
type TransitionManager struct {
	active    bool                // set by Start, cleared by Update when time runs out
	direction TransitionDirection // slide direction chosen by Start

	oldStep, newStep int // steps being transitioned from and to

	// Duration of the whole transition, in seconds.
	totalSec float32

	// Share of totalSec assigned to each sub-phase; expected to sum to 1.0.
	// inFrac is not read by Start, GetPhase or Update: the final sub-phase
	// simply runs until linear progress reaches 1.0.
	fadeFrac, slideFrac, inFrac float32

	// Linear progress values (fractions of totalSec) at which the fade-out
	// and slide sub-phases end; filled in by Start.
	fadeEnd, slideEnd float32

	// Seconds elapsed since Start, advanced by Update.
	accumSec float32
}

// NewTransitionManager returns a manager configured for a 0.6-second
// transition split into three sub-phases: fade/scale out (30% of the time),
// slide (40%), and fade/scale in (30%). The fractions are expected to sum to
// 1.0; the final sub-phase always ends when linear progress reaches 1.0.
func NewTransitionManager() *TransitionManager {
	return &TransitionManager{
		totalSec:  0.6, // entire transition takes 600ms
		fadeFrac:  0.3, // 30% of total time for fade out scale
		slideFrac: 0.4, // 40% of total time for slide
		inFrac:    0.3, // 30% of total time for fade in scale
	}
}

// Start begins a transition from oldStep to newStep and resets the clock.
// The slide direction is DirectionForward only when newStep > oldStep;
// equal or smaller steps use DirectionBackward.
func (t *TransitionManager) Start(oldStep, newStep int) {
	dir := DirectionBackward
	if newStep > oldStep {
		dir = DirectionForward
	}

	t.oldStep, t.newStep = oldStep, newStep
	t.direction = dir

	// Cache the linear-progress points where the fade-out and slide
	// sub-phases end; the fade-in sub-phase runs on to 1.0.
	t.fadeEnd = t.fadeFrac
	t.slideEnd = t.fadeEnd + t.slideFrac

	t.accumSec = 0
	t.active = true
}

// IsActive reports whether a transition is running: true after Start, false
// once Update has accumulated totalSec seconds.
func (t *TransitionManager) IsActive() bool { return t.active }

// GetPhase reports which sub-phase the transition is in, based on the time
// accumulated by Update. It returns TransitionNone when no transition is
// active.
//
// A non-positive totalSec (e.g. a zero-value manager) is treated as an
// already-finished transition and reports the final phase. Without this
// guard the division would rely on NaN/Inf comparisons, and a negative
// duration would report the first phase.
func (t *TransitionManager) GetPhase() TransitionPhase {
	if !t.active {
		return TransitionNone
	}
	if t.totalSec <= 0 {
		return TransitionInFadeScale
	}

	// Linear (un-eased) progress; fadeEnd and slideEnd are linear fractions too,
	// so phases line up with the branch selection in Update.
	p := t.accumSec / t.totalSec
	switch {
	case p < t.fadeEnd:
		return TransitionOutFadeScale
	case p < t.slideEnd:
		return TransitionOutSlide
	default:
		return TransitionInFadeScale
	}
}

// easeInOutCubic maps linear progress p to eased progress with a cubic
// curve: it accelerates through the first half (4p³) and decelerates
// symmetrically through the second (½(2p−2)³ + 1). For p in [0, 1] the
// result stays in [0, 1] and easeInOutCubic(0.5) == 0.5. Inputs outside
// [0, 1] are not clamped.
func easeInOutCubic(p float32) float32 {
	if p >= 0.5 {
		g := (2 * p) - 2
		return 0.5*g*g*g + 1
	}
	return 4 * p * p * p
}

// Update advances the transition clock by rl.GetFrameTime() seconds and
// returns the alpha, scale and horizontal offset (pixels, derived from
// rl.GetScreenWidth()) to draw the old and new steps with this frame.
// Every call advances the clock, so call it exactly once per frame.
//
// When no transition is active, or when this call pushes the clock to or
// past totalSec (which also ends the transition), it returns the identity
// pose (alpha 1, scale 1, offset 0) for both steps.
//
// Sub-phases are selected from linear progress, but motion inside each one
// follows a single easeInOutCubic curve over the whole transition,
// renormalized to [0, 1] for that sub-phase.
func (t *TransitionManager) Update() (
	oldAlpha, oldScale, oldOffsetX float32,
	newAlpha, newScale, newOffsetX float32,
) {
	if !t.active {
		return 1, 1, 0, 1, 1, 0
	}

	// Accumulate variable frame time (seconds).
	t.accumSec += rl.GetFrameTime()
	if t.accumSec >= t.totalSec {
		t.active = false
		return 1, 1, 0, 1, 1, 0
	}

	screenW := float32(rl.GetScreenWidth())
	slideDir := float32(1)
	if t.direction == DirectionBackward {
		slideDir = -1
	}

	// Linear progress. The early return above guarantees accumSec < totalSec,
	// so p < 1 here for non-negative frame times; no clamp is needed.
	p := t.accumSec / t.totalSec
	globalP := easeInOutCubic(p)

	// phaseProgress renormalizes globalP to [0, 1] within the sub-phase that
	// spans linear progress [from, to]. A degenerate (empty) sub-phase yields 0.
	phaseProgress := func(from, to float32) float32 {
		start := easeInOutCubic(from)
		span := easeInOutCubic(to) - start
		if span > 0 {
			return (globalP - start) / span
		}
		return 0
	}

	switch {
	case p < t.fadeEnd:
		// Old step dims to 0.8 and shrinks to 0.9 in place; the new step is
		// held invisible one screen width away on the incoming side.
		e := phaseProgress(0, t.fadeEnd)
		oldAlpha = 1 - 0.2*e
		oldScale = 1 - 0.1*e
		oldOffsetX = 0
		newAlpha = 0
		newScale = 0.9
		newOffsetX = screenW * slideDir

	case p < t.slideEnd:
		// Both steps slide together: the old one moves out by one screen
		// width while the new one moves in to offset 0.
		e := phaseProgress(t.fadeEnd, t.slideEnd)
		oldAlpha = 0.8
		oldScale = 0.9
		oldOffsetX = -screenW * e * slideDir
		newAlpha = 0.8
		newScale = 0.9
		newOffsetX = screenW * (1 - e) * slideDir

	default:
		// Old step stays parked off-screen and hidden; the new step grows
		// back to full alpha and scale at offset 0.
		e := phaseProgress(t.slideEnd, 1)
		oldAlpha = 0
		oldScale = 0.9
		oldOffsetX = -screenW * slideDir
		newAlpha = 0.8 + 0.2*e
		newScale = 0.9 + 0.1*e
		newOffsetX = 0
	}

	return oldAlpha, oldScale, oldOffsetX, newAlpha, newScale, newOffsetX
}