from util.util import (
    get_hwnetid,
    dict_align_str,
)
from util.util_i2c import (
    i2c_read,
    i2c_write,
)

import datetime
import struct
import sys
from collections import namedtuple
from enum import Enum

# ICS SFP MODULE SLAVE ADDRESSES
# I2C slave addresses exposed by the ICS SFP module.
ICS_SFP_SLAVE_ADDR_MSA = 0x50  # MSA serial-ID table; read by sfp_query_module()
ICS_SFP_SLAVE_ADDR_DMI = 0x51  # presumably the diagnostics (DMI) page -- confirm
ICS_SFP_SLAVE_ADDR_MDIO_BRIDGE = 0x56
ICS_SFP_SLAVE_ADDR_MDIO_BRIDGE_TECHNICA = 0x40
ICS_SFP_SLAVE_ADDR_ICS_CONTROL = 0x1C  # target of every transmit_i2c_ICSSFP_* write
ICS_SFP_SLAVE_ADDR_ICS_BOOTLOADER = 0x57
# ICS SFP MODULE BOOTLOADER COMMANDS
# Bootloader command codes. They are not referenced by the code in this
# file; presumably they are sent to ICS_SFP_SLAVE_ADDR_ICS_BOOTLOADER -- confirm.
ICS_SFP_BL_GET_STATUS = 0x00
ICS_SFP_BL_GET_VERSION = 0x12
ICS_SFP_BL_SEND_FW = 0x14
ICS_SFP_BL_RESET_TO_BL = 0x15
ICS_SFP_BL_FLASH_VALIDATE = 0x16
ICS_SFP_BL_FLASH_INIT = 0x17
ICS_SFP_BL_FLASH_START = 0x18
ICS_SFP_BL_FLASH_ERASE = 0x20
ICS_SFP_BL_RESET_TO_APP = 0x21
ICS_SFP_BL_VALIDATE_SW_VERS = 0x27
ICS_SFP_BL_GET_ERROR = 0x28
# ICS SFP MODULE ICS CONTROL BYTE OFFSETS
# Register offsets on the ICS control slave. The values mirror those of
# SFP_ICS_CONFIG_SUBCOMMANDS.
# NOTE(review): offsets 2-6 are named for MDIO settings here, while
# SFP_ICS_CONFIG_SUBCOMMANDS names the same values MDIO_PHY_UPDATE,
# TC10_WAKEUP and RESERVED_4..6 -- confirm which naming matches the firmware.
ICS_SFP_CONFIG_REG_SLEEP_OFFSET = 0
ICS_SFP_CONFIG_REG_BL_OFFSET = 1
ICS_SFP_CONFIG_REG_MDIO_SPEED_OFFSET = 2
ICS_SFP_CONFIG_REG_MDIO_LINKMODE_OFFSET = 3
ICS_SFP_CONFIG_REG_MDIO_PHYMODE_OFFSET = 4
ICS_SFP_CONFIG_REG_MDIO_AUTONEG_OFFSET = 5
ICS_SFP_CONFIG_REG_MDIO_ENABLE_OFFSET = 6
ICS_SFP_CONFIG_REG_PHY_TEMPERATURE_OFFSET = 7
ICS_SFP_CONFIG_REG_FW_MINOR_VERS_OFFSET = 8
ICS_SFP_CONFIG_REG_FW_MAJOR_VERS_OFFSET = 9
ICS_SFP_CONFIG_REG_WRITE_MACSEC_CFG_OFFSET = 10
ICS_SFP_CONFIG_REG_CONFIGURE_MACSEC_RULE_OFFSET = 11
ICS_SFP_CONFIG_REG_CONFIGURE_MACSEC_MAP_OFFSET = 12
ICS_SFP_CONFIG_REG_CONFIGURE_MACSEC_SECY_OFFSET = 13
ICS_SFP_CONFIG_REG_CONFIGURE_MACSEC_SC_OFFSET = 14
ICS_SFP_CONFIG_REG_CONFIGURE_MACSEC_SA_OFFSET = 15
# vendor fields in MSA table
# Bit masks for the vendor override flags. These are presumably applied to
# the ``ics_overrides`` byte decoded by sfp_msa_decode(). The ``_N`` suffix
# suggests the flags are active-low -- confirm against the firmware.
ICS_SFP_VENDOR_BRIDGE_ADDR_OVERRIDE_N_MASK = 0x01
ICS_SFP_VENDOR_SOFT_OPTIONS_OVERRIDE_N_MASK = 0x02
ICS_SFP_VENDOR_ROTARY_SWITCH_OVERRIDE_N_MASK = 0x04


class SFP_ICS_CONFIG_SUBCOMMANDS(Enum):
    """Subcommand codes for the ICS control slave (ICS_SFP_SLAVE_ADDR_ICS_CONTROL).

    The transmit_i2c_ICSSFP_* helpers pass a member's ``value`` as the
    one-byte command argument of ``i2c_write``, followed by the payload.
    Only the MACsec members (10-15) are used by the code in this file.

    NOTE(review): values 2-6 are named differently here than the
    ICS_SFP_CONFIG_REG_MDIO_*_OFFSET constants with the same values;
    confirm which naming matches the module firmware.
    """

    CMD_WRITE_SLEEP = 0
    CMD_WRITE_BOOTLOADER = 1
    CMD_WRITE_MDIO_PHY_UPDATE = 2
    CMD_WRITE_TC10_WAKEUP = 3
    CMD_RESERVED_4 = 4
    CMD_RESERVED_5 = 5
    CMD_RESERVED_6 = 6
    CMD_READ_PHY_TEMPERATURE = 7
    CMD_READ_FW_MINOR = 8
    CMD_READ_FW_MAJOR = 9
    # MACsec: 10 commits/updates configuration, 11-15 upload one table entry each
    CMD_WRITE_MACSEC_UPDATE = 10
    CMD_WRITE_MACSEC_CFG_RULE = 11
    CMD_WRITE_MACSEC_CFG_MAP = 12
    CMD_WRITE_MACSEC_CFG_SECY = 13
    CMD_WRITE_MACSEC_CFG_SC = 14
    CMD_WRITE_MACSEC_CFG_SA = 15


def _transmit_ics_control(device, netid, subcommand, length, data):
    """Write ``data`` to the ICS control slave under ``subcommand``.

    All transmit_i2c_ICSSFP_* helpers go through here, so the slave address
    and the one-byte subcommand framing are defined in a single place.

    :param device: device handle passed through to ``i2c_write``.
    :param netid: hardware network id passed through to ``i2c_write``.
    :param subcommand: ``SFP_ICS_CONFIG_SUBCOMMANDS`` member; its value is
        sent as the single (length 1) command byte.
    :param length: payload length argument for ``i2c_write``.
    :param data: payload, a list of byte values.
    :return: whatever ``i2c_write`` returns.
    """
    return i2c_write(
        device,
        netid,
        ICS_SFP_SLAVE_ADDR_ICS_CONTROL,
        1,
        [subcommand.value],
        length,
        data,
    )


def transmit_i2c_ICSSFP_UPDATE_MACSEC(device, netid, data):
    """Send the MACsec update command with its flag payload.

    The payload length is fixed at 2; sfp_ics_update_macsec() always builds
    exactly two flag bytes.

    :return: whatever ``i2c_write`` returns.
    """
    return _transmit_ics_control(
        device, netid, SFP_ICS_CONFIG_SUBCOMMANDS.CMD_WRITE_MACSEC_UPDATE, 2, data
    )


# NOTE: in the helpers below the ``len`` parameter shadows the builtin; the
# name is kept so callers passing it by keyword keep working.


def transmit_i2c_ICSSFP_CONFIG_MACSEC_RULE(device, netid, len, data):
    """Upload one serialized MACsec rule; returns the ``i2c_write`` result."""
    return _transmit_ics_control(
        device, netid, SFP_ICS_CONFIG_SUBCOMMANDS.CMD_WRITE_MACSEC_CFG_RULE, len, data
    )


def transmit_i2c_ICSSFP_CONFIG_MACSEC_MAP(device, netid, len, data):
    """Upload one serialized MACsec map entry; returns the ``i2c_write`` result."""
    return _transmit_ics_control(
        device, netid, SFP_ICS_CONFIG_SUBCOMMANDS.CMD_WRITE_MACSEC_CFG_MAP, len, data
    )


def transmit_i2c_ICSSFP_CONFIG_MACSEC_SECY(device, netid, len, data):
    """Upload one serialized MACsec SecY entry; returns the ``i2c_write`` result."""
    return _transmit_ics_control(
        device, netid, SFP_ICS_CONFIG_SUBCOMMANDS.CMD_WRITE_MACSEC_CFG_SECY, len, data
    )


def transmit_i2c_ICSSFP_CONFIG_MACSEC_SC(device, netid, len, data):
    """Upload one serialized MACsec SC entry; returns the ``i2c_write`` result."""
    return _transmit_ics_control(
        device, netid, SFP_ICS_CONFIG_SUBCOMMANDS.CMD_WRITE_MACSEC_CFG_SC, len, data
    )


def transmit_i2c_ICSSFP_CONFIG_MACSEC_SA(device, netid, len, data):
    """Upload one serialized MACsec SA entry; returns the ``i2c_write`` result."""
    return _transmit_ics_control(
        device, netid, SFP_ICS_CONFIG_SUBCOMMANDS.CMD_WRITE_MACSEC_CFG_SA, len, data
    )


def sfp_msa_decode(data):
    """Decode a 128-byte SFP MSA (SFF-8472 serial ID) table into a dict.

    Also prints a hex dump of the raw table, 16 bytes per row.

    :param data: the table bytes (anything accepted by ``bytes()``).
    :return: dict keyed by the namedtuple field names below, plus
        ``cc_base_valid`` and ``cc_ext_valid`` (1 if the stored check code
        matches the computed one, else 0). Vendor name/PN/rev/SN and the ICS
        PCB serial are decoded, stripped ``str``. ``date_year`` (+2000),
        ``date_month`` and ``date_day`` become ``int`` when their
        two-character field is decimal; otherwise they are left as raw ``bytes``.
    :raises struct.error: if ``data`` is not exactly 128 bytes long.
    """
    msa = namedtuple(
        "msa",
        "identifier ext_identifier connector transceiver encoding br_nominal l1 l2 l3 l4 l5 vendor_name vendor_oui vendor_pn vendor_rev cc_base options br_max br_min vendor_sn date_year date_month date_day date_lot diag_mon_type enh_options sff8472_compl cc_ext ics_mdio_bridge_addr vendor_data ics_pcb_serial ics_app_id ics_overrides",
    )
    data = bytes(data)
    values = struct.unpack(
        ">BBBQBB1xBBBBB1x16s1x3s16s4s3xBHBB16s2s2s2s2sBBBBB13s16sBB", data
    )

    hex_str = data.hex()
    row = 16 * 2  # 16 bytes per row, two hex digits per byte
    rows = (hex_str[i : i + row] for i in range(0, len(hex_str), row))
    print("MSA table raw:\n" + "\n".join(rows))

    table = msa._make(values)._asdict()

    # Check codes: CC_BASE (byte 63) is the low 8 bits of the sum of bytes
    # 0-62; CC_EXT (byte 95) is the low 8 bits of the sum of bytes 64-94.
    cc_base = sum(data[0:63]) & 0xFF
    cc_ext = sum(data[64:95]) & 0xFF
    table["cc_base_valid"] = 1 if cc_base == table["cc_base"] else 0
    table["cc_ext_valid"] = 1 if cc_ext == table["cc_ext"] else 0

    # Space-padded text fields -> str.
    for key in ("vendor_name", "vendor_pn", "vendor_rev", "vendor_sn", "ics_pcb_serial"):
        table[key] = table[key].decode("utf-8", errors="ignore").strip()

    # Two-character date code fields (YY, MM, DD). Use isdecimal(), not
    # isdigit(): isdigit() also accepts characters such as "²" that int()
    # rejects, which would abort the remaining conversions.
    for key, offset in (("date_year", 2000), ("date_month", 0), ("date_day", 0)):
        text = table[key].decode("utf-8", errors="ignore").strip()
        if text.isdecimal():
            table[key] = int(text) + offset

    return table


def sfp_query_module(device, netid):
    """
    Look for an SFP module on an I2C network.

    Reads the MSA table (128 bytes requested, starting at register 0x00 of
    ICS_SFP_SLAVE_ADDR_MSA). It then decodes everything after the first
    returned element with sfp_msa_decode(). The first element is presumably
    not part of the table -- confirm against ``i2c_read``.

    :return: the decoded MSA dict, or None when ``i2c_read`` returns None
        (no module answered).
    """
    raw = i2c_read(device, netid, ICS_SFP_SLAVE_ADDR_MSA, 1, [0x00], 128, [0] * 128)
    if raw is not None:
        return sfp_msa_decode(raw[1:])
    return None


# Bit positions within the bootloader status flags.
# NOTE(review): these are presumably the flags returned for
# ICS_SFP_BL_GET_STATUS; they are not referenced by the code in this file --
# confirm against the bootloader firmware.
BL_FLAGS_FW_VALID_OFFSET = 0
BL_FLAGS_APP_ERROR_OFFSET = 1
BL_FLAGS_ERASE_IN_PROGRESS_OFFSET = 2
BL_FLAGS_FLASH_IN_PROGRESS_OFFSET = 3
BL_FLAGS_RESET_IN_PROGRESS_OFFSET = 4
BL_FLAGS_BL_READY_OFFSET = 5
BL_FLAGS_FLASH_READY_OFFSET = 6
BL_FLAGS_ENABLE_XTEA_OFFSET = 7


def sfp_ics_send_macsec_rule(device, netid, rule):
    """Serialize a MACsec rule dict and upload it to the module.

    Payload order (widths in bytes; multi-byte integers little-endian;
    fields without a width are appended as one list element, unchanged):
    index; key/mask MAC_DA (6 each); key/mask MAC_SA (6 each);
    key/mask Ethertype (2 each); key_outer1, mask_outer1, key_outer2 and
    mask_outer2, each serialized as VID (2), PRI_CFI, MPLS_label (4), exp;
    key/mask bonus_data (2 each); key/mask tag_match_bitmap;
    key/mask packet_type; key/mask inner_vlan_type (2 each);
    key/mask outer_vlan_type (2 each); key/mask num_tags; key/mask express;
    isMPLS (1); reserved[0:5]; enable (1).

    :param rule: dict with the keys listed above. MAC fields must be
        indexable with at least 6 items, and ``reserved`` with at least 5.
    :return: the result of the underlying I2C write, so callers can detect
        a failed upload (this was previously discarded).
    """
    data = []

    def put_le(value, width):
        # Append an int as ``width`` little-endian byte values.
        data.extend(value.to_bytes(width, "little"))

    def put_items(seq, count):
        # Append the first ``count`` items; IndexError if ``seq`` is shorter.
        data.extend(seq[i] for i in range(count))

    def put_outer(tag):
        # One outer tag: VLAN (VID, PRI/CFI) followed by MPLS (label, exp).
        put_le(tag["vlanTag"]["VID"], 2)
        data.append(tag["vlanTag"]["PRI_CFI"])
        put_le(tag["mpls"]["MPLS_label"], 4)
        data.append(tag["mpls"]["exp"])

    data.append(rule["index"])
    put_items(rule["key_MAC_DA"], 6)
    put_items(rule["mask_MAC_DA"], 6)
    put_items(rule["key_MAC_SA"], 6)
    put_items(rule["mask_MAC_SA"], 6)
    put_le(rule["key_Ethertype"], 2)
    put_le(rule["mask_Ethertype"], 2)
    put_outer(rule["key_outer1"])
    put_outer(rule["mask_outer1"])
    put_outer(rule["key_outer2"])
    put_outer(rule["mask_outer2"])
    put_le(rule["key_bonus_data"], 2)
    put_le(rule["mask_bonus_data"], 2)
    data.append(rule["key_tag_match_bitmap"])
    data.append(rule["mask_tag_match_bitmap"])
    data.append(rule["key_packet_type"])
    data.append(rule["mask_packet_type"])
    put_le(rule["key_inner_vlan_type"], 2)
    put_le(rule["mask_inner_vlan_type"], 2)
    put_le(rule["key_outer_vlan_type"], 2)
    put_le(rule["mask_outer_vlan_type"], 2)
    data.append(rule["key_num_tags"])
    data.append(rule["mask_num_tags"])
    data.append(rule["key_express"])
    data.append(rule["mask_express"])
    put_le(rule["isMPLS"], 1)
    put_items(rule["reserved"], 5)
    put_le(rule["enable"], 1)
    return transmit_i2c_ICSSFP_CONFIG_MACSEC_RULE(device, netid, len(data), data)


def sfp_ics_send_macsec_map(device, netid, map):
    """Serialize a MACsec map entry dict and upload it to the module.

    Payload order (widths in bytes; multi-byte integers little-endian;
    fields without a width are appended as one list element, unchanged):
    index, sectag_sci (8), secYIndex, isControlPacket (1), scIndex,
    auxiliary_plcy (1), ruleId, reserved[0:5], enable (1).

    ``map`` shadows the builtin; the parameter name is kept for callers.

    :return: the result of the underlying I2C write (previously discarded).
    """
    data = [map["index"]]
    data.extend(map["sectag_sci"].to_bytes(8, "little"))
    data.append(map["secYIndex"])
    data.extend(map["isControlPacket"].to_bytes(1, "little"))
    data.append(map["scIndex"])
    data.extend(map["auxiliary_plcy"].to_bytes(1, "little"))
    data.append(map["ruleId"])
    data.extend(map["reserved"][i] for i in range(5))
    data.extend(map["enable"].to_bytes(1, "little"))
    return transmit_i2c_ICSSFP_CONFIG_MACSEC_MAP(device, netid, len(data), data)


def sfp_ics_send_macsec_secy(device, netid, secy):
    """Serialize a MACsec SecY dict and upload it to the module.

    Payload order (widths in bytes; multi-byte integers little-endian;
    fields without a width are appended as one list element, unchanged):
    index, controlled_port_enabled (1), validate_frames, strip_sectag_icv,
    cipher, confidential_offset, icv_includes_da_sa (1), replay_protect (1),
    replay_window (4), protect_frames (1), sectag_offset, sectag_tci,
    mtu (2), reserved[0:6], enable (1).

    :return: the result of the underlying I2C write (previously discarded).
    """
    data = [secy["index"]]
    data.extend(secy["controlled_port_enabled"].to_bytes(1, "little"))
    data.append(secy["validate_frames"])
    data.append(secy["strip_sectag_icv"])
    data.append(secy["cipher"])
    data.append(secy["confidential_offset"])
    data.extend(secy["icv_includes_da_sa"].to_bytes(1, "little"))
    data.extend(secy["replay_protect"].to_bytes(1, "little"))
    data.extend(secy["replay_window"].to_bytes(4, "little"))
    data.extend(secy["protect_frames"].to_bytes(1, "little"))
    data.append(secy["sectag_offset"])
    data.append(secy["sectag_tci"])
    data.extend(secy["mtu"].to_bytes(2, "little"))
    data.extend(secy["reserved"][i] for i in range(6))
    data.extend(secy["enable"].to_bytes(1, "little"))
    return transmit_i2c_ICSSFP_CONFIG_MACSEC_SECY(device, netid, len(data), data)


def sfp_ics_send_macsec_sc(device, netid, sc):
    """Serialize a MACsec secure-channel dict and upload it to the module.

    Payload order (widths in bytes; multi-byte integers little-endian;
    fields without a width are appended as one list element, unchanged):
    index, secYIndex, sci (8), sa_index0, sa_index1, sa_index0_in_use (1),
    sa_index1_in_use (1), enable_auto_rekey (1), isActiveSA1 (1),
    reserved[0:7], enable (1).

    :return: the result of the underlying I2C write (previously discarded).
    """
    data = [sc["index"], sc["secYIndex"]]
    data.extend(sc["sci"].to_bytes(8, "little"))
    data.append(sc["sa_index0"])
    data.append(sc["sa_index1"])
    for key in ("sa_index0_in_use", "sa_index1_in_use", "enable_auto_rekey", "isActiveSA1"):
        data.extend(sc[key].to_bytes(1, "little"))
    data.extend(sc["reserved"][i] for i in range(7))
    data.extend(sc["enable"].to_bytes(1, "little"))
    return transmit_i2c_ICSSFP_CONFIG_MACSEC_SC(device, netid, len(data), data)


def sfp_ics_send_macsec_sa(device, netid, sa):
    """Serialize a MACsec secure-association dict and upload it to the module.

    Payload order (widths in bytes; multi-byte integers little-endian;
    fields without a width are appended as one list element, unchanged):
    index, sak[0:32], hashKey[0:16], salt[0:12], ssci (4), AN,
    nextPN (8), reserved[0:5], enable (1).

    All 32 ``sak`` items are always sent. The payload carries key material,
    so avoid logging it.

    :return: the result of the underlying I2C write (previously discarded).
    """
    data = [sa["index"]]
    data.extend(sa["sak"][i] for i in range(32))
    data.extend(sa["hashKey"][i] for i in range(16))
    data.extend(sa["salt"][i] for i in range(12))
    data.extend(sa["ssci"].to_bytes(4, "little"))
    data.append(sa["AN"])
    data.extend(sa["nextPN"].to_bytes(8, "little"))
    data.extend(sa["reserved"][i] for i in range(5))
    data.extend(sa["enable"].to_bytes(1, "little"))
    return transmit_i2c_ICSSFP_CONFIG_MACSEC_SA(device, netid, len(data), data)


def sfp_ics_update_macsec(
    device, netid, rx, rule, map, secy, sc, sa, nvm, en, clr, rst
):
    """Send the two-byte MACsec update command to the ICS control slave.

    Every argument after ``netid`` is a flag; any truthy value sets its bit.

    Byte 0: bit0 rule, bit1 map, bit2 secy, bit3 sc, bit4 sa, bit5 rx,
    bit6 nvm, bit7 en.
    Byte 1: bit0 clr, bit1 rst.

    :return: the result of the underlying I2C write (previously discarded).
    """

    def pack_bits(*flags):
        # Pack flags LSB-first. Each is normalized to 0/1 so a value such as 2
        # cannot spill into a neighbouring flag's bit or overflow the byte.
        return sum((1 if flag else 0) << pos for pos, flag in enumerate(flags))

    data = [
        pack_bits(rule, map, secy, sc, sa, rx, nvm, en),
        pack_bits(clr, rst),
    ]
    return transmit_i2c_ICSSFP_UPDATE_MACSEC(device, netid, data)


def sfp_query_verify_macsec_support(device, netid_config):
    """Find the SFP module on ``netid_config``, print its MSA details and
    verify it is the MACsec-capable SFP-MV2221M-B1.

    Terminates the process with exit status 1 (``sys.exit``) when no module
    answers, or when the vendor part number is not SFP-MV2221M-B1.

    :param device: device handle passed through to the I2C helpers.
    :param netid_config: network identifier, converted with ``get_hwnetid``.
    :return: I2C address of the module's MDIO bridge. This is the address
        stored in the MSA table when the bridge-address override bit is
        cleared, otherwise 0x40.
    """
    netid = get_hwnetid(netid_config)
    # query for a MACsec-supported SFP module
    print("Checking for SFP modules...")
    msa = sfp_query_module(device, netid)
    if msa is None:
        print(f"\nNo SFP module found on port {netid_config}... exiting.")
        sys.exit(1)

    # add information to dictionary (insertion order is the print order)
    d = {}
    d["Vendor PN"] = msa["vendor_pn"]
    d["Vendor"] = msa["vendor_name"]
    d["Vendor Revision"] = msa["vendor_rev"]
    d["Vendor SN"] = msa["vendor_sn"]
    try:
        man_date = datetime.date(msa["date_year"], msa["date_month"], msa["date_day"])
        x = man_date.strftime("%Y/%m/%d")
    except (TypeError, ValueError, OverflowError):
        # Non-numeric date fields stay raw bytes (TypeError); numbers can
        # still form an impossible calendar date (ValueError).
        x = f"{msa['date_year']}/{msa['date_month']}/{msa['date_day']} -- invalid"
    d["Manufacture date"] = f"{x}"

    # (label, MSA field, bit mask) for each capability flag, in print order.
    capability_bits = (
        ("Supports LOS Pin", "options", 0x02),
        ("Supports LOS Pin (inverted)", "options", 0x04),
        ("Supports TX Fault Pin", "options", 0x08),
        ("Supports TX Disable Pin", "options", 0x10),
        ("Supports Rate Select Pin", "options", 0x20),
        ("Supports Soft LOS", "enh_options", 0x10),
        ("Supports Soft TX Fault", "enh_options", 0x20),
        ("Supports Soft TX Disable", "enh_options", 0x40),
        ("Supports Soft Rate Select", "enh_options", 0x08),
        ("Supports Soft Alarm/Warning Flags", "enh_options", 0x80),
        ("Supports Diagnostic Monitoring Interface", "diag_mon_type", 0x40),
    )
    for label, field, mask in capability_bits:
        d[label] = 1 if msa[field] & mask else 0

    # Report each failing check code under its own key, so an invalid CC_EXT
    # does not overwrite (and hide) an invalid CC_BASE.
    if not msa["cc_base_valid"]:
        d["Check Code"] = hex(msa["cc_base"]) + " -- invalid"
    if not msa["cc_ext_valid"]:
        d["Extended Check Code"] = hex(msa["cc_ext"]) + " -- invalid"

    # The override flag is tested inverted (mask name ends in _N): a cleared
    # bit selects the bridge address stored in the MSA table.
    i2c_mdio_address = ICS_SFP_SLAVE_ADDR_MDIO_BRIDGE_TECHNICA
    if not (msa["ics_overrides"] & ICS_SFP_VENDOR_BRIDGE_ADDR_OVERRIDE_N_MASK):
        i2c_mdio_address = msa["ics_mdio_bridge_addr"]

    # print all information
    print(dict_align_str(d))

    # verify macsec supported part
    if d["Vendor PN"] != "SFP-MV2221M-B1":
        print(f"\nNo SFP-MV2221M-B1 module found on port {netid_config}... exiting.")
        sys.exit(1)

    return i2c_mdio_address
