refactor(code): Split validity calc and matrix extract

Validity calculation belongs to the modelling, so we put it into the
validity module.

Extracting our matrix is a processing step so we made its own matrix
module and put it in their.
Should hopefully provide better separation of concerns going forward.
This commit is contained in:
Marty Oehme 2024-02-16 11:25:19 +01:00
parent 8333bbe9be
commit fac7d4c86a
Signed by: Marty
GPG key ID: EDBF2ED917B2EF6A
5 changed files with 103 additions and 89 deletions

View file

@ -41,7 +41,7 @@ cmd = "nvim"
[tool.poe.tasks.extract] [tool.poe.tasks.extract]
help = "Extract the csv data from raw yaml files" help = "Extract the csv data from raw yaml files"
shell = """ shell = """
python src/prep_data.py > 02-data/processed/extracted.csv python src/matrix.py > 02-data/processed/extracted.csv
""" """
[tool.poe.tasks.milestone] [tool.poe.tasks.milestone]
help = "Extract, render, commit and version a finished artifact" help = "Extract, render, commit and version a finished artifact"

View file

@ -613,9 +613,9 @@ to better identify areas of strong analytical lenses or areas of more limited an
```{python} ```{python}
#| label: fig-validity #| label: fig-validity
from src import prep_data from src.model import validity
validities = prep_data.calculate_validities(by_intervention) validities = validity.calculate(by_intervention)
validities["identifier"] = validities["author"].str.replace(r',.*$', '', regex=True) + " (" + validities["year"].astype(str) + ")" validities["identifier"] = validities["author"].str.replace(r',.*$', '', regex=True) + " (" + validities["year"].astype(str) + ")"
g = sns.PairGrid(validities[["internal_validity", "external_validity", "identifier"]].drop_duplicates(subset="identifier"), g = sns.PairGrid(validities[["internal_validity", "external_validity", "identifier"]].drop_duplicates(subset="identifier"),

38
src/matrix.py Normal file
View file

@ -0,0 +1,38 @@
from io import StringIO
from pathlib import Path
from pandas import DataFrame
try:
from src.model import validity # for quarto document scripts
except ModuleNotFoundError:
from model import validity # for directly running the package
def extract(df: DataFrame, file: Path | StringIO) -> None:
(
validity.calculate(df)
.drop(labels=["observation"], axis="columns")
.to_csv(file, index=False, encoding="utf-8")
)
if __name__ == "__main__":
import os
import sys
import load_data
if len(sys.argv) == 2:
df = load_data.from_yml(Path(sys.argv[1]))
else:
df = load_data.from_yml()
output = StringIO()
extract(df, output)
output.seek(0)
try:
print(output.read())
except BrokenPipeError:
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, sys.stdout.fileno())

View file

@ -1,9 +1,8 @@
import math import math
from typing import cast
from pandas import DataFrame from pandas import DataFrame
from src import prep_data
def _binned_strength(strength: float) -> str: def _binned_strength(strength: float) -> str:
if strength < 3.0: if strength < 3.0:
@ -28,15 +27,74 @@ def _combined_validities(
return r"\-" return r"\-"
def calculate(
df: DataFrame, repr_col: str = "representativeness", method_col: str = "method"
) -> DataFrame:
EXT_COL_NAME: str = "external_validity"
INT_COL_NAME: str = "internal_validity"
cols = {EXT_COL_NAME: 0.0, INT_COL_NAME: 0.0}
vd = df[
(df["design"] == "quasi-experimental") | (df["design"] == "experimental")
].copy()
vd.assign(**cols)
vd = cast(DataFrame, vd)
vd[repr_col] = vd[repr_col].fillna("")
vd[method_col] = vd[method_col].fillna("")
# needs to check national before subnational and census, subnational before local
vd.loc[
vd[repr_col].str.contains("|".join(["national", "regional"])), EXT_COL_NAME
] = 4.0
vd.loc[vd[repr_col].str.contains("census"), EXT_COL_NAME] = 5.0
vd.loc[vd[repr_col].str.contains("subnational"), EXT_COL_NAME] = 3.0
vd.loc[vd[repr_col].str.contains("local"), EXT_COL_NAME] = 2.0
# needs to go lowest to highest in case of multiple mentioned approaches
vd.loc[
vd[method_col].str.contains(
"|".join(["OLS", "ordinary.least.square", "logistic.regression"])
),
INT_COL_NAME,
] = 2.0
vd.loc[
vd[method_col].str.contains("|".join(["DM", "discontinuity.matching"])),
INT_COL_NAME,
] = 3.0
vd.loc[
vd[method_col].str.contains(
"|".join(["DID", "difference.in.diff", "diff.in.diff", "triple.diff"])
),
INT_COL_NAME,
] = 3.0
vd.loc[
vd[method_col].str.contains(
"|".join(["PSM", "propensity.score.matching", "score.matching"])
),
INT_COL_NAME,
] = 3.5
vd.loc[
vd[method_col].str.contains("|".join(["IV", "instrumental.variable"])),
INT_COL_NAME,
] = 4.0
vd.loc[
vd[method_col].str.contains("|".join(["RD", "regression.discontinuity"])),
INT_COL_NAME,
] = 4.5
vd.loc[vd[method_col].str.contains("RCT"), INT_COL_NAME] = 5.0
return vd
def add_to_findings( def add_to_findings(
findings_df: DataFrame, studies_by_intervention: DataFrame findings_df: DataFrame, studies_by_intervention: DataFrame
) -> DataFrame: ) -> DataFrame:
valid_subset = ( valid_subset = (
prep_data.calculate_validities(studies_by_intervention)[ calculate(studies_by_intervention)[
["internal_validity", "external_validity", "citation"] ["internal_validity", "external_validity", "citation"]
] ]
.fillna(1.0) .fillna(1.0)
.drop_duplicates(subset=["citation"]) # type: ignore .drop_duplicates(subset=["citation"]) # type: ignore
.sort_values("internal_validity") .sort_values("internal_validity")
) )

View file

@ -1,82 +0,0 @@
from typing import cast
from pandas import DataFrame
def calculate_validities(
df: DataFrame, repr_col: str = "representativeness", method_col: str = "method"
) -> DataFrame:
EXT_COL_NAME: str = "external_validity"
INT_COL_NAME: str = "internal_validity"
cols = {EXT_COL_NAME: 0.0, INT_COL_NAME: 0.0}
vd = df[
(df["design"] == "quasi-experimental") | (df["design"] == "experimental")
].copy()
vd.assign(**cols)
vd = cast(DataFrame, vd)
vd[repr_col] = vd[repr_col].fillna("")
vd[method_col] = vd[method_col].fillna("")
# needs to check national before subnational and census, subnational before local
vd.loc[vd[repr_col].str.contains("|".join(["national", "regional"])), EXT_COL_NAME] = 4.0
vd.loc[vd[repr_col].str.contains("census"), EXT_COL_NAME] = 5.0
vd.loc[vd[repr_col].str.contains("subnational"), EXT_COL_NAME] = 3.0
vd.loc[vd[repr_col].str.contains("local"), EXT_COL_NAME] = 2.0
# needs to go lowest to highest in case of multiple mentioned approaches
vd.loc[
vd[method_col].str.contains("|".join(["OLS", "ordinary.least.square", "logistic.regression"])),
INT_COL_NAME,
] = 2.0
vd.loc[
vd[method_col].str.contains("|".join(["DM", "discontinuity.matching"])),
INT_COL_NAME,
] = 3.0
vd.loc[
vd[method_col].str.contains(
"|".join(["DID", "difference.in.diff", "diff.in.diff", "triple.diff"])
),
INT_COL_NAME,
] = 3.0
vd.loc[
vd[method_col].str.contains(
"|".join(["PSM", "propensity.score.matching", "score.matching"])
),
INT_COL_NAME,
] = 3.5
vd.loc[
vd[method_col].str.contains("|".join(["IV", "instrumental.variable"])),
INT_COL_NAME,
] = 4.0
vd.loc[
vd[method_col].str.contains("|".join(["RD", "regression.discontinuity"])),
INT_COL_NAME,
] = 4.5
vd.loc[vd[method_col].str.contains("RCT"), INT_COL_NAME] = 5.0
return vd
if __name__ == "__main__":
import os
import sys
from io import StringIO
from pathlib import Path
import load_data
if len(sys.argv) == 2:
df = load_data.from_yml(Path(sys.argv[1]))
else:
df = load_data.from_yml()
df = calculate_validities(df)
output = StringIO()
df.to_csv(output)
output.seek(0)
try:
print(output.read())
except BrokenPipeError:
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, sys.stdout.fileno())