from typing import cast

from pandas import DataFrame


def calculate_validities(
    df: DataFrame, repr_col: str = "representativeness", method_col: str = "method"
) -> DataFrame:
    """Score external and internal validity for (quasi-)experimental studies.

    Keeps only rows whose ``design`` column equals ``"experimental"`` or
    ``"quasi-experimental"`` and adds two float columns:

    * ``external_validity`` -- from keywords in ``repr_col``
      (national 5, regional 4, subnational 3, local 2).
    * ``internal_validity`` -- from method keywords/abbreviations in
      ``method_col`` (OLS 2, DM/DID 3, PSM 3.5, IV 4, RD 4.5, RCT 5).

    Matching is a case-sensitive regex substring search. Rows matching no
    keyword score 0.0; when several keywords match, the rule listed last
    above wins.

    Args:
        df: Input table; must contain ``design``, ``repr_col`` and
            ``method_col`` columns. It is not modified.
        repr_col: Column describing sample representativeness.
        method_col: Column describing the analysis method(s).

    Returns:
        A filtered copy of ``df`` (original index preserved) in which missing
        values in ``repr_col`` and ``method_col`` are replaced by ``""`` and
        the two validity columns have been added.
    """
    EXT_COL_NAME: str = "external_validity"
    INT_COL_NAME: str = "internal_validity"

    # Rules are applied in order; later matches overwrite earlier ones.
    # "national" is also a substring of "subnational", so the broader term
    # must come first for the narrower one to override it.
    external_rules = (
        ("national", 5.0),
        ("regional", 4.0),
        ("subnational", 3.0),
        ("local", 2.0),
    )
    # Lowest to highest, so a study mentioning several approaches keeps the
    # score of its strongest one. Each entry's alternatives are OR-ed into
    # one regex ('.' deliberately matches any separator character).
    internal_rules = (
        (("OLS", "ordinary.least.square"), 2.0),
        (("DM", "discontinuity.matching"), 3.0),
        (("DID", "difference.in.diff", "diff.in.diff", "triple.diff"), 3.0),
        (("PSM", "propensity.score.matching", "score.matching"), 3.5),
        (("IV", "instrumental.variable"), 4.0),
        (("RD", "regression.discontinuity"), 4.5),
        (("RCT",), 5.0),
    )

    vd = cast(
        DataFrame,
        df[df["design"].isin(["quasi-experimental", "experimental"])].copy(),
    )
    # assign() returns a new frame, so the result must be kept. Float
    # defaults mean unmatched rows score 0.0 (not NaN) and writing the
    # fractional scores below needs no int -> float upcast.
    vd = vd.assign(**{EXT_COL_NAME: 0.0, INT_COL_NAME: 0.0})

    vd[repr_col] = vd[repr_col].fillna("")
    vd[method_col] = vd[method_col].fillna("")

    # na=False: any non-string cell counts as "no match" instead of producing
    # a NaN in the boolean mask, which .loc would reject.
    for pattern, score in external_rules:
        mask = vd[repr_col].str.contains(pattern, na=False)
        vd.loc[mask, EXT_COL_NAME] = score
    for alternatives, score in internal_rules:
        mask = vd[method_col].str.contains("|".join(alternatives), na=False)
        vd.loc[mask, INT_COL_NAME] = score

    return vd


if __name__ == "__main__":
    import os
    import sys
    from pathlib import Path

    import load_data

    # An optional single CLI argument is the path passed to load_data.from_yml;
    # otherwise from_yml() is called with no arguments.
    if len(sys.argv) == 2:
        df = load_data.from_yml(Path(sys.argv[1]))
    else:
        df = load_data.from_yml()

    df = calculate_validities(df)
    try:
        # DataFrame.to_csv() with no path returns the CSV text directly.
        print(df.to_csv())
        # Flush inside the try so an EPIPE (e.g. output piped into `head`)
        # is raised here, not during interpreter shutdown where it can't be
        # caught.
        sys.stdout.flush()
    except BrokenPipeError:
        # Python flushes standard streams on exit; redirect stdout to devnull
        # so that final flush does not raise BrokenPipeError again.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)