#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#  nvidia-installer
#
#  Copyright © 2013-2017 Antergos
#
#  This file is part of Antergos.
#
#  Antergos is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  Antergos is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  The following additional terms are in effect as per Section 7 of the license:
#
#  The preservation of all legal notices and author attributions in
#  the material or in the Appropriate Legal Notices displayed
#  by works containing it is required.
#
#  You should have received a copy of the GNU General Public License
#  along with Antergos; If not, see <http://www.gnu.org/licenses/>.

# 2019 Modified for EndeavourOS.
# Maintainer: @manuel
# Contributor: joekamprad <joekamprad@endeavouros.com>

""" nvidia-installer eases nvidia video driver installation """

import argparse
import getpass
import os
import logging
import subprocess
import sys

VERSION = "2.1"  # will be overwritten by PKGBUILD

# Dry-run switch. main() replaces it with the value of the --test option.
TEST = True

# Root and unprivileged runs write to different log files
LOG_FILE = (
    "/tmp/nvidia-installer_r.log" if os.getuid() == 0
    else "/tmp/nvidia-installer.log")

# Directory that load_ids() scans for *nvidia.ids files
IDS_PATH = "/var/lib/pci"

# ANSI escape sequences used to colour the log messages
YELLOW = '\033[93m'
GREEN = '\033[92m'
RED = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'

# driver name -> list of supported pci product ids (filled in by load_ids())
DEVICES = {"nvidia-dkms": []}

# Values that check_device() compares against the ids parsed from `lspci -n`:
# the first two hex digits of the device class and the vendor id.
CLASS_ID = "0x03"
VENDOR_ID = "0x10de"

# Packages installed for each driver
PACKAGES = {
    "nouveau": ["xf86-video-nouveau", "mesa"],
    "nvidia-dkms": ["nvidia-dkms", "libvdpau", "nvidia-settings"],
    "bumblebee": ["bumblebee", "mesa", "xf86-video-intel",
                  "nvidia-dkms", "virtualgl", "nvidia-settings", "bbswitch-dkms"]}

# Extra packages installed when the linux-lts kernel is present
PACKAGES_LTS = {
    "nouveau": [],
    "nvidia-dkms": [],
    "bumblebee": []}

# Extra (lib32) packages installed on x86_64 machines
PACKAGES_X86_64 = {
    "nouveau": ["lib32-mesa"],
    "nvidia-dkms": ["lib32-nvidia-utils", "lib32-libvdpau"],
    "bumblebee": ["lib32-nvidia-utils", "lib32-virtualgl", "lib32-mesa"]}

# Packages removed before each driver is installed.
# Every name needs its own comma: adjacent string literals are silently
# concatenated by Python ("nvidia-lts" "nvidia-utils" used to become the
# non-existent package "nvidia-ltsnvidia-utils").
CONFLICTS = {
    "nouveau": ["nvidia-dkms", "nvidia", "nvidia-lts", "nvidia-utils", "nvidia-libgl",
                "bumblebee", "virtualgl", "nvidia-settings", "bbswitch", "bbswitch-lts", "bbswitch-dkms"],
    "nvidia-dkms": ["xf86-video-nouveau", "nvidia-lts", "nvidia"],
    "bumblebee": ["xf86-video-nouveau", "nvidia-libgl"]}

# Extra (lib32) packages removed on x86_64 machines
CONFLICTS_X86_64 = {
    "nouveau": ["lib32-nvidia-libgl", "lib32-nvidia-utils", "lib32-virtualgl"],
    "nvidia-dkms": ["lib32-nvidia-libgl"],
    "bumblebee": ["lib32-nvidia-libgl"]}


def get_class_vendor_product(line):
    """ Extracts the ids from one line of `lspci -n` output

    Returns a (class_id, vendor_id, product_id) tuple of "0x"-prefixed
    strings. Only the first two characters of the class field are kept. """
    fields = line.split()
    # Second field is the device class followed by a colon
    device_class = fields[1].rstrip(":")
    # Third field is "vendor:product"
    ids = fields[2].split(":")
    return ("0x" + device_class[:2], "0x" + ids[0], "0x" + ids[1])


def check_device():
    """ Tries to guess if a device suitable for this driver is present

    Runs `lspci -n` and returns the list of drivers (keys of DEVICES) whose
    id list contains the product id of a device matching CLASS_ID and
    VENDOR_ID. Each driver is listed at most once, even when several matching
    cards are installed. Returns None when lspci cannot be run. """
    cmd = ["/usr/bin/lspci", "-n"]
    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        log_warning("Cannot detect hardware components : {}".format(err.output.decode()))
        return None
    except OSError as err:
        # lspci is missing or cannot be executed
        log_warning("Cannot detect hardware components : {}".format(err))
        return None

    drivers = []
    for line in output.decode().split("\n"):
        if not line.strip():
            continue

        class_id, vendor_id, product_id = get_class_vendor_product(line)
        if class_id != CLASS_ID or vendor_id != VENDOR_ID:
            continue

        for driver, product_ids in DEVICES.items():
            # Two cards handled by the same driver must not list it twice
            if product_id in product_ids and driver not in drivers:
                drivers.append(driver)
    return drivers


def choose_driver(drivers):
    """ Picks one driver when more than one is suitable

    nvidia-dkms always wins when it is among the candidates. Otherwise the
    user is asked to pick an entry from a numbered menu; the entry after the
    last driver (or any bigger number) quits the program with exit code 0.
    Answers that are not a positive number are rejected and asked again.

    Returns the name of the chosen driver. """
    if "nvidia-dkms" in drivers:
        # Lists every candidate, whatever their number is
        log_info("Chose nvidia-dkms out of possible {}.".format(" and ".join(drivers)))
        return "nvidia-dkms"

    print("There is more than one suitable driver. Please choose:")
    for number, driver in enumerate(drivers, start=1):
        print("{0}. {1}".format(number, driver))
    quit_number = len(drivers) + 1
    print("{0}. Quit".format(quit_number))

    while True:
        answer = input("Your selection: ")
        try:
            selection = int(answer)
        except ValueError:
            selection = 0

        if selection >= quit_number:
            sys.exit(0)
        if selection >= 1:
            return drivers[selection - 1]

        # Zero, negative or non-numeric answer: ask again
        print("Please enter a number between 1 and {0}.".format(quit_number))


def install(driver):
    """ Performs packages installation

    Removes the packages that conflict with `driver` (only those that are
    installed) and then installs the driver packages with pacman. In test
    mode the pacman commands are only logged.

    Returns True on success, False as soon as a pacman call fails. """

    # Work on copies: extending the module-level tables in place would make
    # them grow on every call.
    packages = list(PACKAGES[driver])
    conflicts = list(CONFLICTS[driver])

    if os.uname()[-1] == "x86_64":
        packages.extend(PACKAGES_X86_64[driver])
        conflicts.extend(CONFLICTS_X86_64[driver])

    if os.path.exists('/boot/vmlinuz-linux-lts'):
        packages.extend(PACKAGES_LTS[driver])

    log_info("Removing conflicting packages...")
    installed_packages = get_installed_packages()

    cmd = ["pacman", "-Rs", "--noconfirm", "--noprogressbar", "--nodeps"]
    for conflict in conflicts:
        if conflict not in installed_packages:
            continue
        if TEST:
            log_info(" ".join(cmd + [conflict]))
            continue
        try:
            subprocess.check_output(cmd + [conflict], stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            msg = "Cannot remove conflicting package {0}: {1}"
            log_error(msg.format(conflict, err.output.decode()))
            return False

    log_info("Downloading and installing driver packages, please wait as this may take a few minutes...")
    cmd = ["pacman", "-Sqy", "--noconfirm", "--noprogressbar"]
    cmd.extend(packages)
    if TEST:
        log_info(" ".join(cmd))
        return True

    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        msg = "Cannot install required packages: {}"
        log_error(msg.format(err.output.decode()))
        return False

    return True


def get_installed_packages():
    """ Returns the names of all installed packages (first column of
    `pacman -Q`). An empty list is returned, and a warning logged, when
    pacman cannot be queried. """
    failure = "Cannot check installed packages (pacman -Q): {}"
    try:
        listing = subprocess.check_output(["pacman", "-Q"], stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        log_warning(failure.format(err.output.decode()))
        return []
    except OSError as err:
        log_warning(failure.format(err))
        return []

    # Each non-empty line reads "<name> <version>"
    rows = (row.split() for row in listing.decode().split('\n'))
    return [columns[0] for columns in rows if columns]


def add_user_to_group(user, group):
    """ Puts `user` into `group` with gpasswd (the command is only logged
    in test mode). A failure is logged as a warning. """
    log_info("Adding user {0} to {1} group...".format(user, group))
    gpasswd = ["gpasswd", "-a", user, group]
    log_info(" ".join(gpasswd))
    if TEST:
        return

    try:
        subprocess.check_output(gpasswd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        log_warning(
            "Cannot add user {0} to the {1} group: {2}".format(
                user, group, err.output.decode()))


def update_nvidia_cards_database():
    """ Update Nvidia cards database

    Runs the nvidia-installer-update-db helper. In test mode the helper is
    called with --tmpdb and IDS_PATH is switched to /tmp so that the ids are
    later loaded from there.

    Returns True when the helper succeeded, False when it failed or could
    not be started. """
    # Needed for the assignment below to reach the module-level constant;
    # without it the assignment only creates an unused local variable.
    global IDS_PATH

    log_info("Updating Nvidia graphics cards database...")
    if TEST:
        # NOTE(review): presumably --tmpdb makes the helper write its
        # database below /tmp - confirm against nvidia-installer-update-db.
        IDS_PATH = "/tmp"
        cmd = ["nvidia-installer-update-db", "--tmpdb"]
        log_info(" ".join(cmd))
    else:
        cmd = ["nvidia-installer-update-db"]

    try:
        subprocess.check_output(cmd)
    except (subprocess.CalledProcessError, OSError):
        # OSError: the helper is not installed or not executable
        return False

    return True


def get_user():
    """ Returns the name found in the SUDO_USER environment variable, or
    the current login name when that variable is not set """
    sudo_user = os.environ.get('SUDO_USER')
    return getpass.getuser() if sudo_user is None else sudo_user


def enable_service(service, enable=True):
    """ Enables (or disables, when `enable` is False) a systemd unit

    ".service" is appended to the name when missing. Nothing happens when
    the unit file is not found in /usr/lib/systemd/system; in test mode the
    systemctl command is only logged. """
    unit = service if service.endswith(".service") else service + ".service"

    if not os.path.exists("/usr/lib/systemd/system/{}".format(unit)):
        return

    announce, action = (
        ("Enabling {} service...", "enable") if enable
        else ("Disabling {} service...", "disable"))
    log_info(announce.format(unit))

    systemctl = ["systemctl", action, unit]
    log_info(" ".join(systemctl))
    if TEST:
        return

    try:
        subprocess.check_output(systemctl, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        log_warning(
            "Cannot enable/disable {0} service: {1}".format(unit, err.output.decode()))


def patch_nvidia_settings(patch=True):
    """ Rewrites the Exec line of nvidia-settings.desktop

    With `patch` the launcher is changed to start through optirun; without
    it the plain command is restored. Nothing happens when the desktop file
    does not exist; in test mode the sed command is only logged. """
    desktop_path = "/usr/share/applications/nvidia-settings.desktop"
    if not os.path.exists(desktop_path):
        return

    plain_exec = "Exec=/usr/bin/nvidia-settings"
    optirun_exec = "Exec=optirun -b none /usr/bin/nvidia-settings -c :8"
    if patch:
        log_info("Patching {}...".format(desktop_path))
        script = "s|{0}|{1}|".format(plain_exec, optirun_exec)
    else:
        log_info("Unpatching {}...".format(desktop_path))
        script = "s|{1}|{0}|".format(plain_exec, optirun_exec)

    sed_cmd = ["/usr/bin/sed", "-i", script, desktop_path]
    if TEST:
        log_info(" ".join(sed_cmd))
        return

    try:
        subprocess.check_output(sed_cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        log_warning(
            "Cannot modify {0} file : {1}".format(desktop_path, err.output.decode()))


def nvidia_kernel_params():
    """ Fix Nvidia kernel parameters

    Runs the nvidia-installer-kernel-para helper (only logged in test mode).

    Returns True on success or in test mode. Returns False, after logging a
    warning, when the helper fails or cannot be started. """
    log_info("Patching kernel line with nvidia-drm.modeset=1")
    cmd = ["nvidia-installer-kernel-para"]
    if TEST:
        log_info(" ".join(cmd))
        return True

    try:
        subprocess.check_output(cmd)
    except (subprocess.CalledProcessError, OSError) as err:
        # Report the problem here: a bare False return is easy to overlook,
        # and OSError covers a helper that is not installed.
        log_warning("Cannot run {0}: {1}".format(" ".join(cmd), err))
        return False

    return True


def remove_file(path):
    """ Deletes `path` when it exists. In test mode the removal is only
    announced in the log. """
    log_info("Removing {} file...".format(path))
    if TEST:
        return

    if not os.path.exists(path):
        log_info("{} not found. That's ok.".format(path))
        return

    os.remove(path)


def post_install(driver):
    """ Runs the actions needed once the packages are installed

    bumblebee: adds the user to the bumblebee and video groups, enables
    bumblebeed and patches the nvidia-settings launcher.
    Any other driver: disables bumblebeed and unpatches the launcher.
    nvidia* drivers get the kernel parameters patched; for all other drivers
    the 20-nvidia.conf xorg snippet is removed.
    mkinitcpio.conf is fixed in every case. """

    nvidia_conf_path = "/etc/X11/xorg.conf.d/20-nvidia.conf"
    use_bumblebee = driver == "bumblebee"

    if use_bumblebee:
        user = get_user()
        if user == "root":
            log_warning(
                "NOT adding user 'root' to bumblebee or video groups! "
                "Remember to add your users to those groups")
        else:
            for group in ("bumblebee", "video"):
                add_user_to_group(user, group)

    # Service and launcher are switched on for bumblebee, off otherwise
    enable_service("bumblebeed.service", use_bumblebee)
    patch_nvidia_settings(use_bumblebee)

    if driver.startswith("nvidia"):
        # add kernel parameter
        nvidia_kernel_params()
    else:
        # bumblebee and nouveau
        remove_file(nvidia_conf_path)

    fix_mkinitcpio()


def _strip_gpu_modules(line):
    """ Returns a MODULES line without its nouveau and nvidia* entries.
    The line is returned untouched when it holds no such entry. """
    # Function-scope import: nothing else in this block needs re
    import re

    # Remove whole module names (nvidia, nvidia_drm, nvidia-uvm, ...) together
    # with the blanks that follow them. Replacing just the "nvidia" substring
    # would leave fragments such as "_drm" behind.
    stripped, removed = re.subn(r"\b(?:nouveau|nvidia)[\w-]*[ \t]*", "", line)
    if not removed:
        return line

    # Tidy up the separators left behind in comma separated, quoted lists
    for leftover, fixed in ((",,", ","), ('",', '"'), (',"', '"')):
        stripped = stripped.replace(leftover, fixed)
    return stripped


def fix_mkinitcpio():
    """ Removes nouveau and nvidia from MODULES line in mkinitcpio.conf

    When a line is changed (and test mode is off) the file is rewritten and
    the initramfs is regenerated for the linux preset and, when the lts
    kernel is present, for the linux-lts preset. A failing preset is reported
    and the remaining ones are still processed. Nothing is done when
    /etc/mkinitcpio.conf does not exist. """

    mkinitcpio_path = "/etc/mkinitcpio.conf"

    # Read mkinitcpio.conf
    try:
        with open(mkinitcpio_path) as mkinitcpio_file:
            mklines = mkinitcpio_file.readlines()
    except FileNotFoundError:
        # Without the file there is nothing to fix
        log_info("{} not found. That's ok.".format(mkinitcpio_path))
        return

    modified = False
    new_mklines = []
    for line in mklines:
        # Remove nouveau and nvidia from MODULES line (just in case)
        if line.startswith("MODULES"):
            new_line = _strip_gpu_modules(line)
            if new_line != line:
                msg = "{0} will be changed to {1} in {2}"
                log_info(msg.format(line.strip("\n"), new_line.strip("\n"), mkinitcpio_path))
                modified = True
                line = new_line
        new_mklines.append(line)

    if not modified or TEST:
        return

    try:
        with open(mkinitcpio_path, "w") as mkinitcpio_file:
            mkinitcpio_file.writelines(new_mklines)
    except PermissionError:
        log_error("This script must be run with administrative privileges!")
        return

    presets = ["linux"]
    if os.path.exists('/boot/vmlinuz-linux-lts'):
        presets.append("linux-lts")

    for preset in presets:
        cmd = ["mkinitcpio", "-p", preset]
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        except PermissionError:
            log_error("This script must be run with administrative privileges!")
        except (subprocess.CalledProcessError, OSError) as err:
            # OSError: mkinitcpio itself could not be started
            msg = "Cannot run {0}: {1}, please run it manually before rebooting!"
            log_warning(msg.format(' '.join(cmd), err))


def parse_options():
    """ Builds the command line parser and returns the parsed options.
    Every option is a boolean switch. """
    parser = argparse.ArgumentParser(
        description="EndeavourOS Nvidia Installer v{0}".format(VERSION))

    # (short name, long name, help text), in the order shown by --help
    switches = (
        ("-b", "--bumblebee",
         "For Nvidia Optimus cards (Bumblebee + proprietary Nvidia drivers)"),
        ("-f", "--force",
         "Force driver installation even if a nvidia card is not detected"),
        ("-t", "--test",
         "Test mode. Nothing in your system will be modified"),
        ("-q", "--quiet",
         "Supress log messages"),
        ("-n", "--nouveau",
         "Restores nouveau (open) nvidia driver"),
    )
    for short_name, long_name, help_text in switches:
        parser.add_argument(
            short_name, long_name, help=help_text, action="store_true")

    return parser.parse_args()


def setup_logging(cmd_line):
    """ Configure our logger

    Resets the root logger and attaches a console handler (unless
    cmd_line.quiet is set) and a file handler writing to LOG_FILE. A log file
    that cannot be opened is reported and the installer goes on without it.
    """
    logger = logging.getLogger()

    logger.handlers = []

    log_level = logging.INFO

    logger.setLevel(log_level)

    # Log format
    formatter = logging.Formatter(
        fmt="%(asctime)s [%(levelname)s]: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S")

    # Nothing may be logged before the handlers below are attached: the
    # module-level logging functions install a default stderr handler when
    # the root logger has none, which duplicates every console message and
    # defeats --quiet. That is why the old log is not removed through
    # remove_file() here; opening it with mode='w' truncates it anyway.

    # Console logger (attached first, so that a log file problem is
    # reported through it)
    if not cmd_line.quiet:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(log_level)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    # File logger
    try:
        file_handler = logging.FileHandler(LOG_FILE, mode='w')
    except OSError as os_error:
        # PermissionError is the usual case; other OSErrors are handled alike
        log_error("Can't open {0} : {1}".format(LOG_FILE, os_error))
    else:
        file_handler.setLevel(log_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    log_info("EndeavourOS Nvidia Installer v{}".format(VERSION))
    log_info("All logs will be stored in {}".format(LOG_FILE))


def log_error(msg):
    """ Logs `msg` at error level, coloured red """
    logging.error("".join((RED, msg, ENDC)))


def log_warning(msg):
    """ Logs `msg` at warning level, coloured yellow """
    logging.warning("".join((YELLOW, msg, ENDC)))


def log_info(msg):
    """ Logs `msg` at info level, coloured green """
    logging.info("".join((GREEN, msg, ENDC)))


def load_ids_file(path):
    """ Reads the whitespace separated pci ids stored in `path` and returns
    them lower-cased and prefixed with "0x" """
    with open(path, 'r') as ids_file:
        content = ids_file.read()

    # split() without arguments also splits on the line breaks
    return ["0x" + pci_id for pci_id in content.lower().split()]


def load_ids():
    """ Fills DEVICES from every *nvidia.ids file found in IDS_PATH. The
    driver name is the file name with ".ids" replaced by "-dkms". """
    for file_name in os.listdir(IDS_PATH):
        if not file_name.endswith('nvidia.ids'):
            continue
        driver = file_name[:-len(".ids")] + "-dkms"
        ids_path = os.path.join(IDS_PATH, file_name)
        DEVICES[driver] = load_ids_file(ids_path)


def main():
    """ Program entry point: parses the options, selects a driver and
    installs it """
    global TEST

    # Command line options
    cmd_line = parse_options()
    setup_logging(cmd_line)
    TEST = cmd_line.test

    # A real run needs root; a test run must not have it
    is_root = os.getuid() == 0
    if not TEST and not is_root:
        log_error("This installer must be run with administrative privileges (use sudo)")
        return
    if TEST and is_root:
        log_error("Do not run this installer in test mode with administrative privileges")
        return

    if TEST:
        log_info("Running the installer in testing mode...")

    if not update_nvidia_cards_database():
        log_warning(
            "Cannot create Nvidia cards database in /tmp" if TEST
            else "Cannot update Nvidia cards database in /var/lib/pci")

    try:
        load_ids()
    except FileNotFoundError:
        log_error("Cannot load ids files from /var/lib/pci")
        return

    # An explicit option wins over hardware detection
    driver = None
    if cmd_line.bumblebee:
        driver = "bumblebee"
    elif cmd_line.nouveau:
        driver = "nouveau"
    else:
        detected = check_device()
        if detected:
            # Ask choose_driver() only when there is something to choose
            driver = choose_driver(detected) if len(detected) > 1 else detected[0]

    if not driver and cmd_line.force:
        log_warning(
            "Your graphics card is not in our database. "
            "Forcing nvidia-dkms driver installation anyways...")
        log_warning(
            "\n"
            "  An older card needs an older driver installed before reboot.\n"
            "  See https://www.nvidia.com/en-us/drivers/unix about supported driver versions\n"
            "  and install the correct driver e.g. with command:\n"
            "      yay -S nvidia-470xx-dkms\n"
            "  (the above example was for the 470 series driver).")
        driver = "nvidia-dkms"

    if not driver:
        log_error("Couldn't find a driver suitable for your graphics card.")
        log_error("If you have an older nvidia card, you may use the --force option to install nvidia-dkms and follow further instructions.")
        return

    log_info("Installing {} driver...".format(driver))
    if not install(driver):
        return

    post_install(driver)
    # The nvidia package contains a file which blacklists the nouveau
    # module, so rebooting is necessary.
    if TEST:
        log_info("Installation finished. Nothing has been modified as testing mode was ON.")
    else:
        log_info("Installation finished. You need to reboot now!")


# Run the installer only when this file is executed as a script
if __name__ == "__main__":
    main()
