package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// DB wraps a pgx connection pool; the telemetry write and query methods in
// this file hang off it.
type DB struct{ pool *pgxpool.Pool }

// initDB opens a pgx connection pool, verifies connectivity with a ping and
// applies schema migrations via runMigrations.
//
// The connection string is read from the DATABASE_URL environment variable;
// when it is unset or empty a local development DSN is used instead.
// NOTE(review): that fallback embeds a plaintext password — consider requiring
// DATABASE_URL outside local development.
//
// On any failure the pool (if created) is closed and a wrapped error is
// returned, so callers never receive a half-initialised *DB.
func initDB(ctx context.Context) (*DB, error) {
	dsn := os.Getenv("DATABASE_URL")
	if dsn == "" {
		dsn = "postgresql://postgres:123456@localhost:5432/ic3?sslmode=disable"
		// Do not log the DSN itself: it contains credentials.
		log.Println("DATABASE_URL not set; using local development DSN")
	}
	pool, err := pgxpool.New(ctx, dsn)
	if err != nil {
		return nil, fmt.Errorf("creating connection pool: %w", err)
	}
	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return nil, fmt.Errorf("pinging database: %w", err)
	}
	if err := runMigrations(ctx, pool); err != nil {
		pool.Close()
		return nil, fmt.Errorf("migrations: %w", err)
	}
	log.Println("PostgreSQL connected, all migrations applied")
	return &DB{pool: pool}, nil
}

// Close releases every connection held by the pool. It is a no-op on a nil
// *DB (or one without a pool), matching the nil-receiver handling of the
// query methods, so callers can defer it unconditionally.
func (db *DB) Close() {
	if db == nil || db.pool == nil {
		return
	}
	db.pool.Close()
}

// saveBatch persists every envelope in batch. Each envelope is written in its
// own transaction to:
//   - telemetry       (append-only history; a repeated record_id is ignored)
//   - asset_latest    (upsert – latest state per asset)
//   - assets          (upsert – registry, tracks first/last seen and a record count)
//   - alarm_events    (insert when the asset's alarm state changes to a non-NORMAL value)
//
// Envelopes are independent: a failure on one is logged, its writes are rolled
// back, and processing continues with the next envelope. Once ctx is done the
// remaining envelopes are skipped. A nil *DB is a no-op, matching the nil
// handling of the read methods.
//
// saveBatch has no return value (kept for existing callers), so failures are
// reported only through the standard logger.
func (db *DB) saveBatch(ctx context.Context, batch []TelemetryEnvelope) {
	if db == nil {
		return
	}
	for i := range batch {
		if err := ctx.Err(); err != nil {
			log.Printf("saveBatch: %d of %d envelopes not saved: %v", len(batch)-i, len(batch), err)
			return
		}
		env := &batch[i]
		if err := db.saveEnvelope(ctx, env); err != nil {
			log.Printf("saveBatch: asset %q record %q: %v", env.AssetID, env.RecordID, err)
		}
	}
}

// saveEnvelope writes one envelope to all four tables inside a single
// transaction, so history, latest state and registry cannot diverge when a
// statement fails part-way through.
func (db *DB) saveEnvelope(ctx context.Context, env *TelemetryEnvelope) error {
	payload, err := json.Marshal(env.Values)
	if err != nil {
		return fmt.Errorf("encoding values: %w", err)
	}

	// OccurredAt is handed to PostgreSQL as text to parse; when absent the
	// ingest time is used instead.
	var ts any
	if env.OccurredAt != "" {
		ts = env.OccurredAt
	} else {
		ts = time.Now().UTC()
	}
	quality := coalesce(env.Quality, "GOOD")
	alarm := coalesce(env.Alarm, "NORMAL")

	tx, err := db.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op; its error is irrelevant.
	defer func() { _ = tx.Rollback(ctx) }()

	// 1. Append to telemetry history.
	tag, err := tx.Exec(ctx, `
		INSERT INTO telemetry (
			record_id, timestamp_utc, domain_code, system_id, asset_id,
			source_system, collector_id, sequence_no,
			quality_code, alarm_state,
			location_path, country_code, state_code, district_code,
			taluk_code, city_code, ward_code,
			zone_id, dma_id, pressure_zone_id, site_id,
			latitude, longitude, elevation_m,
			stream_topic, priority, payload
		) VALUES (
			$1,$2,$3,$4,$5,
			$6,$7,$8,
			$9,$10,
			$11,$12,$13,$14,
			$15,$16,$17,
			$18,$19,$20,$21,
			$22,$23,$24,
			$25,$26,$27
		)
		ON CONFLICT (record_id) DO NOTHING`,
		nullText(env.RecordID), ts, env.Domain, env.SystemID, env.AssetID,
		nullText(env.SourceSystem), nullText(env.CollectorID), nullInt(env.SequenceNo),
		quality, alarm,
		nullText(env.LocationPath), nullText(env.CountryCode), nullText(env.StateCode), nullText(env.DistrictCode),
		nullText(env.TalukCode), nullText(env.CityCode), nullText(env.WardCode),
		nullText(env.ZoneID), nullText(env.DMAID), nullText(env.PressureZoneID), nullText(env.SiteID),
		nullFloat(env.Latitude), nullFloat(env.Longitude), nullFloat(env.ElevationM),
		nullText(env.StreamTopic), nullText(env.Priority), payload,
	)
	if err != nil {
		return fmt.Errorf("inserting telemetry: %w", err)
	}
	if tag.RowsAffected() == 0 {
		// record_id already stored (a redelivery): skip the remaining writes so
		// the envelope is not counted twice in assets.total_records.
		return nil
	}

	// 2. Read the previous alarm state before asset_latest is overwritten. It
	// is only needed when this envelope reports an alarm; an asset without a
	// row yet is treated as NORMAL.
	prevAlarm := "NORMAL"
	if alarm != "NORMAL" {
		err := tx.QueryRow(ctx, `
			SELECT COALESCE(
				(SELECT alarm_state FROM asset_latest WHERE asset_id = $1),
				'NORMAL')`,
			env.AssetID,
		).Scan(&prevAlarm)
		if err != nil {
			return fmt.Errorf("reading previous alarm state: %w", err)
		}
	}

	// 3. Upsert asset_latest.
	if _, err := tx.Exec(ctx, `
		INSERT INTO asset_latest (
			asset_id, system_id, domain_code, timestamp_utc,
			quality_code, alarm_state, dma_id, site_id, stream_topic, payload
		) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
		ON CONFLICT (asset_id) DO UPDATE SET
			system_id     = EXCLUDED.system_id,
			domain_code   = EXCLUDED.domain_code,
			timestamp_utc = EXCLUDED.timestamp_utc,
			quality_code  = EXCLUDED.quality_code,
			alarm_state   = EXCLUDED.alarm_state,
			dma_id        = EXCLUDED.dma_id,
			site_id       = EXCLUDED.site_id,
			stream_topic  = EXCLUDED.stream_topic,
			payload       = EXCLUDED.payload,
			updated_at    = NOW()`,
		env.AssetID, env.SystemID, env.Domain, ts,
		quality, alarm,
		nullText(env.DMAID), nullText(env.SiteID), nullText(env.StreamTopic), payload,
	); err != nil {
		return fmt.Errorf("upserting asset_latest: %w", err)
	}

	// 4. Upsert the asset registry; first_seen is only set on first insert.
	if _, err := tx.Exec(ctx, `
		INSERT INTO assets (
			asset_id, system_id, domain_code, site_id, dma_id, zone_id,
			location_path, latitude, longitude, elevation_m,
			first_seen, last_seen, last_quality, last_alarm, total_records
		) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10, $11,$11,$12,$13,1)
		ON CONFLICT (asset_id) DO UPDATE SET
			last_seen     = EXCLUDED.last_seen,
			last_quality  = EXCLUDED.last_quality,
			last_alarm    = EXCLUDED.last_alarm,
			total_records = assets.total_records + 1`,
		env.AssetID, env.SystemID, env.Domain, nullText(env.SiteID), nullText(env.DMAID), nullText(env.ZoneID),
		nullText(env.LocationPath), nullFloat(env.Latitude), nullFloat(env.Longitude), nullFloat(env.ElevationM),
		ts, quality, alarm,
	); err != nil {
		return fmt.Errorf("upserting assets: %w", err)
	}

	// 5. Record an alarm event only when the state actually changes into a
	// non-NORMAL value (NORMAL→X or X→Y), with the real previous state.
	if alarm != "NORMAL" && alarm != prevAlarm {
		alarmID := env.RecordID + ":alarm"
		if env.RecordID == "" {
			// No record_id to derive a stable ID from; fall back to a unique one.
			alarmID = env.AssetID + ":" + fmt.Sprint(time.Now().UnixNano())
		}
		if _, err := tx.Exec(ctx, `
			INSERT INTO alarm_events (
				record_id, timestamp_utc, asset_id, domain_code, system_id,
				dma_id, from_state, to_state, quality_code
			) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)
			ON CONFLICT (record_id) DO NOTHING`,
			alarmID, ts, env.AssetID, env.Domain, env.SystemID,
			nullText(env.DMAID), prevAlarm, alarm, quality,
		); err != nil {
			return fmt.Errorf("inserting alarm event: %w", err)
		}
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing: %w", err)
	}
	return nil
}

// history queries telemetry (v3 column names) and returns at most limit rows,
// ordered by created_at, newest first. Non-empty assetID / domain values add
// equality filters on asset_id / domain_code. Only AssetID, Domain, SystemID,
// OccurredAt (RFC 3339, UTC), Quality, Alarm and Values (decoded from the JSON
// payload) are populated. A nil *DB returns (nil, nil).
//
// Rows that fail to scan are skipped. An error raised while streaming the
// result set is returned instead of being mistaken for the end of the data.
func (db *DB) history(ctx context.Context, assetID, domain string, limit int) ([]TelemetryEnvelope, error) {
	if db == nil {
		return nil, nil
	}
	q := `SELECT asset_id, domain_code, system_id, timestamp_utc,
	             quality_code, alarm_state, payload
	      FROM telemetry WHERE 1=1`
	var args []any
	// Placeholders are numbered from len(args) so they always match the
	// position of their argument.
	if assetID != "" {
		args = append(args, assetID)
		q += fmt.Sprintf(" AND asset_id=$%d", len(args))
	}
	if domain != "" {
		args = append(args, domain)
		q += fmt.Sprintf(" AND domain_code=$%d", len(args))
	}
	args = append(args, limit)
	q += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d", len(args))

	rows, err := db.pool.Query(ctx, q, args...)
	if err != nil {
		return nil, fmt.Errorf("querying telemetry history: %w", err)
	}
	defer rows.Close()

	var out []TelemetryEnvelope
	for rows.Next() {
		var e TelemetryEnvelope
		var raw []byte
		var ts *time.Time
		if err := rows.Scan(&e.AssetID, &e.Domain, &e.SystemID, &ts, &e.Quality, &e.Alarm, &raw); err != nil {
			continue
		}
		if raw != nil {
			// Best effort: a malformed payload leaves Values empty.
			_ = json.Unmarshal(raw, &e.Values)
		}
		if ts != nil {
			e.OccurredAt = ts.UTC().Format(time.RFC3339)
		}
		out = append(out, e)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading telemetry history: %w", err)
	}
	return out, nil
}

// latestAssets reads the current state of every asset from asset_latest,
// most recently updated first, optionally filtered by domain (an empty domain
// means no filter). It returns a TelemetryEnvelope slice so existing callers
// and the WS broadcast format stay unchanged. A nil *DB returns (nil, nil).
//
// dma_id, site_id and stream_topic are nullable (empty strings are written as
// NULL), so they are scanned through *string and left empty when NULL;
// scanning NULL directly into a string field fails and would drop the row.
// An error raised while streaming the result set is returned.
func (db *DB) latestAssets(ctx context.Context, domain string) ([]TelemetryEnvelope, error) {
	if db == nil {
		return nil, nil
	}
	q := `SELECT asset_id, system_id, domain_code, timestamp_utc,
	             quality_code, alarm_state, dma_id, site_id, stream_topic, payload
	      FROM asset_latest`
	var args []any
	if domain != "" {
		q += ` WHERE domain_code = $1`
		args = append(args, domain)
	}
	q += ` ORDER BY updated_at DESC`

	rows, err := db.pool.Query(ctx, q, args...)
	if err != nil {
		return nil, fmt.Errorf("querying asset_latest: %w", err)
	}
	defer rows.Close()

	var out []TelemetryEnvelope
	for rows.Next() {
		var e TelemetryEnvelope
		var raw []byte
		var ts *time.Time
		var dmaID, siteID, streamTopic *string
		if err := rows.Scan(
			&e.AssetID, &e.SystemID, &e.Domain, &ts,
			&e.Quality, &e.Alarm, &dmaID, &siteID, &streamTopic, &raw,
		); err != nil {
			continue
		}
		if dmaID != nil {
			e.DMAID = *dmaID
		}
		if siteID != nil {
			e.SiteID = *siteID
		}
		if streamTopic != nil {
			e.StreamTopic = *streamTopic
		}
		if ts != nil {
			e.OccurredAt = ts.UTC().Format(time.RFC3339)
		}
		if raw != nil {
			// Best effort: a malformed payload leaves Values empty.
			_ = json.Unmarshal(raw, &e.Values)
		}
		out = append(out, e)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading asset_latest: %w", err)
	}
	return out, nil
}

// latestAlarms reads only assets currently in a non-NORMAL alarm state from
// asset_latest, most recently updated first. Each entry is a map with keys
// asset_id, alarm_state, quality, domain, system_id, dma_id and site_id (both
// *string, nil when NULL), ts (RFC 3339, UTC, empty when NULL) and values
// (decoded JSON payload).
//
// The result is never nil, so it encodes as [] rather than null. A nil *DB
// returns an empty slice. Rows that fail to scan are skipped; an error raised
// while streaming the result set is returned.
func (db *DB) latestAlarms(ctx context.Context) ([]map[string]any, error) {
	if db == nil {
		return []map[string]any{}, nil
	}
	rows, err := db.pool.Query(ctx, `
		SELECT asset_id, domain_code, system_id, timestamp_utc,
		       quality_code, alarm_state, dma_id, site_id, payload
		FROM asset_latest
		WHERE alarm_state <> 'NORMAL'
		ORDER BY updated_at DESC`)
	if err != nil {
		return nil, fmt.Errorf("querying active alarms: %w", err)
	}
	defer rows.Close()

	out := []map[string]any{}
	for rows.Next() {
		var assetID, domain, sysID, quality, alarm string
		var dmaID, siteID *string
		var ts *time.Time
		var raw []byte
		if err := rows.Scan(&assetID, &domain, &sysID, &ts, &quality, &alarm, &dmaID, &siteID, &raw); err != nil {
			continue
		}
		var values map[string]any
		if raw != nil {
			// Best effort: a malformed payload leaves values nil.
			_ = json.Unmarshal(raw, &values)
		}
		var tsStr string
		if ts != nil {
			tsStr = ts.UTC().Format(time.RFC3339)
		}
		out = append(out, map[string]any{
			"asset_id":    assetID,
			"alarm_state": alarm,
			"quality":     quality,
			"domain":      domain,
			"system_id":   sysID,
			"dma_id":      dmaID,
			"site_id":     siteID,
			"ts":          tsStr,
			"values":      values,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading active alarms: %w", err)
	}
	return out, nil
}

// stats returns aggregate counts from the DB tables as a map with keys:
//   - total:       rows in telemetry
//   - alarms:      assets in asset_latest whose alarm_state is not NORMAL
//   - domains:     distinct domain_code values in asset_latest
//   - live_assets: rows in asset_latest
//
// All four counts come from a single statement, so they reflect one
// consistent snapshot and cost one round trip. A nil *DB yields all-zero
// counts. Any query error is returned instead of silently reporting zeros.
// NOTE(review): COUNT(*) over the append-only telemetry table may become slow
// as history grows — consider an estimate if this endpoint is hot.
func (db *DB) stats(ctx context.Context) (map[string]any, error) {
	if db == nil {
		return map[string]any{"total": 0, "alarms": 0, "domains": 0, "live_assets": 0}, nil
	}
	var total, domains, liveAssets, liveAlarms int
	err := db.pool.QueryRow(ctx, `
		SELECT (SELECT COUNT(*) FROM telemetry),
		       (SELECT COUNT(DISTINCT domain_code) FROM asset_latest),
		       (SELECT COUNT(*) FROM asset_latest),
		       (SELECT COUNT(*) FROM asset_latest WHERE alarm_state <> 'NORMAL')`,
	).Scan(&total, &domains, &liveAssets, &liveAlarms)
	if err != nil {
		return nil, fmt.Errorf("querying stats: %w", err)
	}
	return map[string]any{
		"total":       total,
		"alarms":      liveAlarms,
		"domains":     domains,
		"live_assets": liveAssets,
	}, nil
}

// ── helpers ───────────────────────────────────────────────────────────────────

// coalesce returns s unless it is the empty string, in which case it returns
// def.
func coalesce(s, def string) string {
	if s != "" {
		return s
	}
	return def
}

// nullText converts the empty string to an untyped nil, which becomes SQL
// NULL when passed as a query argument; any other string is returned as-is.
func nullText(s string) any {
	if len(s) > 0 {
		return s
	}
	return nil
}

// nullFloat converts 0 (including -0) to an untyped nil, which becomes SQL
// NULL when passed as a query argument; any other value, NaN included, is
// returned unchanged.
// NOTE(review): a genuine zero — e.g. a coordinate on the equator/prime
// meridian or an elevation of 0 m — is therefore indistinguishable from
// "missing" and stored as NULL.
func nullFloat(f float64) any {
	if f != 0 {
		return f
	}
	return nil
}

// nullInt converts 0 to an untyped nil, which becomes SQL NULL when passed as
// a query argument; any other value is returned unchanged.
// NOTE(review): a genuine 0 (e.g. a sequence number that starts at 0) is
// therefore stored as NULL.
func nullInt(i int) any {
	if i != 0 {
		return i
	}
	return nil
}
