diff --git a/src/calculate_validities.py b/src/calculate_validities.py new file mode 100644 index 0000000..e3955ce --- /dev/null +++ b/src/calculate_validities.py @@ -0,0 +1,71 @@ +from typing import cast +from pandas import DataFrame + + +def calculate_validities( + df: DataFrame, repr_col: str = "representativeness", method_col: str = "method" +) -> DataFrame: + EXT_COL_NAME: str = "external_validity" + INT_COL_NAME: str = "internal_validity" + vd = df[(df["design"] == "quasi-experimental") | (df["design"] == "experimental")] + + vd[EXT_COL_NAME] = 0 + vd[INT_COL_NAME] = 0 + vd = cast(DataFrame, vd) + + vd[repr_col] = vd[repr_col].fillna("") + # needs to check national before subnational, subnational before local + vd.loc[vd[repr_col].str.contains("national"), EXT_COL_NAME] = 5.0 + vd.loc[vd[repr_col].str.contains("regional"), EXT_COL_NAME] = 4.0 + vd.loc[vd[repr_col].str.contains("subnational"), EXT_COL_NAME] = 3.0 + vd.loc[vd[repr_col].str.contains("local"), EXT_COL_NAME] = 2.0 + + vd[method_col] = vd[method_col].fillna("") + # needs to go lowest to highest in case of multiple mentioned approaches + vd.loc[ + vd[method_col].str.contains("|".join(["OLS", "ordinary.least.square"])), + INT_COL_NAME, + ] = 2.0 + vd.loc[ + vd[method_col].str.contains("|".join(["DM", "discontinuity.matching"])), + INT_COL_NAME, + ] = 3.0 + vd.loc[ + vd[method_col].str.contains( + "|".join(["DID", "difference.in.diff", "diff.in.diff", "triple.diff"]) + ), + INT_COL_NAME, + ] = 3.0 + vd.loc[ + vd[method_col].str.contains( + "|".join(["PSM", "propensity.score.matching", "score.matching"]) + ), + INT_COL_NAME, + ] = 3.5 + vd.loc[ + vd[method_col].str.contains("|".join(["IV", "instrumental.variable"])), + INT_COL_NAME, + ] = 4.0 + vd.loc[ + vd[method_col].str.contains("|".join(["RD", "regression.discontinuity"])), + INT_COL_NAME, + ] = 4.5 + vd.loc[vd[method_col].str.contains("RCT"), INT_COL_NAME] = 5.0 + + return vd + +if __name__ == "__main__": + import sys + import load_data + from pathlib import Path + from io import StringIO + if len(sys.argv) == 2: + df = load_data.from_yml(Path(sys.argv[1])) + else: + df = load_data.from_yml() + + df = calculate_validities(df) + output = StringIO() + df.to_csv(output) + output.seek(0) + print(output.read())