// yt - A fully featured command line YouTube client
//
// Copyright (C) 2024 Benedikt Peetz <benedikt.peetz@b-peetz.de>
// SPDX-License-Identifier: GPL-3.0-or-later
//
// This file is part of Yt.
//
// You should have received a copy of the License along with this program.
// If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.

use std::{fmt::Display, str::FromStr};

use error::BytesError;

/// The base unit: a single byte.
const B: u64 = 1;

// Binary (IEC) multiples: each one is 2^10 = 1024 times the previous unit.
const KIB: u64 = B << 10;
const MIB: u64 = KIB << 10;
const GIB: u64 = MIB << 10;
const TIB: u64 = GIB << 10;
const PIB: u64 = TIB << 10;

// Decimal (SI) multiples: each one is 1000 times the previous unit.
const KB: u64 = B * 1_000;
const MB: u64 = KB * 1_000;
const GB: u64 = MB * 1_000;
const TB: u64 = GB * 1_000;

pub mod error;
#[cfg(feature = "serde")]
pub mod serde;

/// A size in bytes.
///
/// Parse one from a human-readable string with [`str::parse`] (e.g. `"2.34 GiB"`) and
/// render it back with its `Display` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bytes(u64);

impl Bytes {
    /// Return the raw number of bytes.
    #[must_use]
    pub const fn as_u64(self) -> u64 {
        self.0
    }

    /// Wrap a raw number of bytes.
    #[must_use]
    pub const fn new(v: u64) -> Self {
        Self(v)
    }
}

impl FromStr for Bytes {
    type Err = BytesError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.chars().filter(|s| !s.is_whitespace()).collect::<String>();

        let (whole_number, s): (u64, &str) = {
            let number = s.chars().take_while(|x| x.is_numeric()).collect::<String>();

            (number.parse()?, &s[number.len()..])
        };

        let (decimal_number, s, raise_factor) = {
            if s.starts_with('.') {
                let s_str = s
                    .chars()
                    .skip(1) // the decimal point
                    .take_while(|x| x.is_numeric())
                    .collect::<String>();

                let s_num = s_str.parse::<u64>()?;

                (s_num, &s[s_str.len()..], s_str.len() as u32)
            } else {
                (0u64, s, 0)
            }
        };

        let number = (whole_number * (10u64.pow(raise_factor))) + decimal_number;

        let extension = s.chars().skip_while(|x| x.is_numeric()).collect::<String>();

        let output = match extension.to_lowercase().as_str() {
            "" => number,
            "b" => number * B,
            "kib" => number * KIB,
            "mib" => number * MIB,
            "gib" => number * GIB,
            "tib" => number * TIB,
            "kb" => number * KB,
            "mb" => number * MB,
            "gb" => number * GB,
            "tb" => number * TB,
            other => return Err(BytesError::NotYetSupported(other.to_owned())),
        };

        Ok(Self(output / (10u64.pow(raise_factor))))
    }
}

impl Display for Bytes {
    /// Render the size with a binary (IEC) unit, rounded to three significant digits,
    /// e.g. `"512 B"`, `"2.34 GiB"` or `"18.2 TiB"`.
    ///
    /// Sizes below one KiB are printed exactly as a byte count. The whole `u64` range is
    /// supported; sizes of one PiB and above are shown in `PiB` or `EiB` instead of
    /// panicking.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `u64::MAX` is just under 16 EiB, so no unit beyond EiB is ever needed.
        const EIB: u64 = 1024 * PIB;

        let num = self.0;

        let (divisor, unit) = match num {
            0..KIB => return write!(f, "{num} B"),
            KIB..MIB => (KIB, "KiB"),
            MIB..GIB => (MIB, "MiB"),
            GIB..TIB => (GIB, "GiB"),
            TIB..PIB => (TIB, "TiB"),
            PIB..EIB => (PIB, "PiB"),
            EIB.. => (EIB, "EiB"),
        };

        write!(f, "{} {unit}", precision_f64(num as f64 / divisor as f64, 3))
    }
}

// taken from this stack overflow question: https://stackoverflow.com/a/76572321
/// Round to significant digits (rather than digits after the decimal).
///
/// `decimals` is the number of significant digits to keep. Returns `0.0` when `x` is
/// zero or when `decimals` is zero.
///
/// Not implemented for `f32`, because such an implementation showed precision
/// glitches (e.g. `precision_f32(12300.0, 2) == 11999.999`), so for `f32`
/// floats, convert to `f64` for this function and back as needed.
///
/// # Examples
/// ```
///# fn main() {
///# use bytes::precision_f64;
///   assert_eq!(precision_f64(1.2300, 2), 1.2f64);
///   assert_eq!(precision_f64(1.2300_f64, 2), 1.2f64);
///   assert_eq!(precision_f64(1.2300_f32 as f64, 2), 1.2f64);
///   assert_eq!(precision_f64(1.2300_f32 as f64, 2) as f32, 1.2f32);
///# }
/// ```
pub fn precision_f64(x: f64, decimals: u32) -> f64 {
    if x == 0. || decimals == 0 {
        return 0.;
    }

    // Decimal places to keep; negative when rounding to tens, hundreds, ...
    let shift = decimals as i32 - x.abs().log10().ceil() as i32;

    // Only ever scale by a non-negative power of ten. `10^-n` is not exactly
    // representable in binary floating point, so multiplying and then dividing by it
    // can leave a one-ulp error (e.g. `1_900_000.0` rounded to 2 digits came out as
    // `1899999.9999999998`). Dividing by the exact `10^n` and multiplying back avoids that.
    if shift >= 0 {
        let factor = 10_f64.powi(shift);
        (x * factor).round() / factor
    } else {
        let factor = 10_f64.powi(shift.saturating_neg());
        (x / factor).round() * factor
    }
}

#[cfg(test)]
mod tests {
    use super::{Bytes, GIB};

    /// Parse `input`, panicking if it is rejected.
    fn parse(input: &str) -> Bytes {
        input.parse().unwrap()
    }

    /// Assert that parsing `input` and displaying the result yields `input` again.
    fn assert_round_trip(input: &str) {
        assert_eq!(input.to_owned(), parse(input).to_string());
    }

    #[test]
    fn parsing() {
        assert_eq!(20 * GIB, parse("20 GiB").0);
    }

    #[test]
    fn parsing_not_round() {
        assert_eq!("2.34 GiB", parse("2.34 GiB").to_string().as_str());
    }

    #[test]
    fn round_trip_1kib() {
        assert_round_trip("1 KiB");
    }

    #[test]
    fn round_trip_2kib() {
        assert_round_trip("2 KiB");
    }

    #[test]
    fn round_trip_1mib() {
        assert_round_trip("1 MiB");
    }

    #[test]
    fn round_trip_2mib() {
        assert_round_trip("2 MiB");
    }

    #[test]
    fn round_trip_1gib() {
        assert_round_trip("1 GiB");
    }

    #[test]
    fn round_trip_2gib() {
        assert_round_trip("2 GiB");
    }

    #[test]
    fn round_trip() {
        assert_round_trip("20 TiB");
    }

    #[test]
    fn round_trip_decimal() {
        // 20 * 10^12 bytes is roughly 18.19 TiB, shown with three significant digits.
        assert_eq!("18.2 TiB", parse("20 TB").to_string());
    }

    #[test]
    fn round_trip_1b() {
        // A bare number without a unit is interpreted as bytes.
        assert_eq!("1 B", parse("1").to_string());
    }
}