Add csv typer options to cli

This commit is contained in:
Marty Oehme 2025-08-12 12:03:47 +02:00
parent 1c7738f1c0
commit 01d992fd4c
Signed by: Marty
GPG key ID: 4E535BC19C61886E

126
main.py
View file

@ -10,8 +10,8 @@ import typer
BASE_URL = "https://www.nightjet.com"
BASE_DIR = "out"
CSV_LOWEST_FILE = f"{BASE_DIR}/lowest.csv"
CSV_ALL_PRICES_PATTERN = f"{BASE_DIR}/%%DATE%%_all_prices.csv"
CSV_LOWEST_FILE = "lowest.csv"
CSV_ALL_PRICES_PATTERN = "all_prices_%%DATE%%.csv"
NOTIFICATION_CHANNEL = "nightjet-price-notifier"
START_STATION = "8096003" # BerlinHBF
@ -193,22 +193,35 @@ def get_lowest_price(prices: list[Price]) -> Price:
return lowest
def dump_all_prices_to_csv(prices: list[Price]) -> None:
fname = CSV_ALL_PRICES_PATTERN.replace(
"%%DATE%%", str(int(datetime.now().timestamp()))
def dump_all_prices_to_csv(prices: list[Price], fpath: Path) -> None:
fstr = str(fpath)
fpath_replaced = Path(
fstr.replace("%%DATE%%", str(int(datetime.now().timestamp())))
)
with open(fname, "w") as f:
with open(fpath_replaced, "w") as f:
writer = csv.writer(f)
writer.writerow(["id", "price", "name"])
writer.writerows([[price.id, price.price, price.name] for price in prices])
writer.writerow(["id", "price", "ts_from", "ts_to", "name"])
writer.writerows(
[
[
price.id,
price.price,
price.dt_from.timestamp(),
price.dt_to.timestamp(),
price.name,
]
for price in prices
]
)
dprint(f"Dumped current query snapshot into: {fpath_replaced}.")
def add_to_csv(price: Price) -> None:
if not Path(CSV_LOWEST_FILE).is_file():
with open(CSV_LOWEST_FILE, "w") as f:
def add_to_csv(price: Price, file: Path) -> None:
if not file.is_file():
with open(file, "w") as f:
csv.writer(f).writerow(["id", "price", "ts_from", "ts_to", "name"])
with open(CSV_LOWEST_FILE, "a") as f:
with open(file, "a") as f:
csv.writer(f).writerow(
[
price.id,
@ -220,11 +233,11 @@ def add_to_csv(price: Price) -> None:
)
def get_last_price_from_csv() -> Price | None:
if not Path(CSV_LOWEST_FILE).is_file():
def get_last_price_from_csv(file: Path) -> Price | None:
if not file.is_file():
return
with open(CSV_LOWEST_FILE) as f:
with open(file) as f:
last = next(reversed(list(csv.reader(f))))
return Price(
id=last[0],
@ -247,27 +260,71 @@ def notify_user(previous: Price, new: Price, channel: str) -> None:
)
def main(start_station: int, end_station: int, travel_date: datetime):
Path(BASE_DIR).mkdir(exist_ok=True, parents=True)
print(start_station, end_station, travel_date)
# return
def query(start_station: int, end_station: int, travel_date: datetime) -> list[Price]:
token = request_init_token()
connections = request_connections(token, start_station, end_station, travel_date)
booking_requests = connection_data_to_booking_requests(connections)
bookings = [request_bookings(token, req) for req in booking_requests]
prices = extract_prices(bookings)
return prices
## CLI
app = typer.Typer()
@app.command()
def main(
start_station: int = typer.Option(
START_STATION, help="Departure station number. (default: Berlin Hbf)"
),
end_station: int = typer.Option(
END_STATION, help="Destination station number. (default: Paris Est)"
),
travel_date: str = typer.Option(help="Travel day to search from. (YYYY-MM-DD)"),
notification_channel: str = typer.Option(
NOTIFICATION_CHANNEL, help="ntfy channel to inform user on."
),
base_output_directory: Path = typer.Option(
Path(BASE_DIR), help="Directory in which to output all result files."
),
lowest_prices_filename: str = typer.Option(
CSV_LOWEST_FILE, help="Filename for collecting lowest found prices."
),
price_snapshot_pattern: str = typer.Option(
CSV_ALL_PRICES_PATTERN,
help="Filename pattern for saving all prices of each query. Takes %%DATE%% as pattern to replace with current unix timestamp.",
),
dump_price_snapshot: bool = typer.Option(
True, help="Dump _all_ queried prices into a timestamped csv file."
),
):
base_output_directory.mkdir(exist_ok=True, parents=True)
lowest_prices_path = base_output_directory.joinpath(lowest_prices_filename)
price_snapshot_path = base_output_directory.joinpath(price_snapshot_pattern)
try:
date_obj = datetime.strptime(travel_date, "%Y-%m-%d")
except ValueError:
typer.echo(f"Invalid date format: {travel_date}. Use YYYY-MM-DD", err=True)
raise typer.Exit(1)
prices = query(
start_station=start_station, end_station=end_station, travel_date=date_obj
)
# create a snapshot of all current prices
dump_all_prices_to_csv(prices)
if dump_price_snapshot:
dump_all_prices_to_csv(prices, price_snapshot_path)
# extract the lowest and the last lowest price
new = get_lowest_price(prices)
previous = get_last_price_from_csv()
previous = get_last_price_from_csv(lowest_prices_path)
# if the price changed, add it to lowest prices
if not previous or new.price != previous.price:
dprint(f"PRICE CHANGE. {previous} -> {new}")
add_to_csv(new, lowest_prices_path)
notify_user(
previous
or Price(
@ -278,31 +335,8 @@ def main(start_station: int, end_station: int, travel_date: datetime):
datetime.fromtimestamp(0),
),
new,
NOTIFICATION_CHANNEL,
notification_channel,
)
add_to_csv(new)
## CLI
app = typer.Typer()
@app.command()
def search(
start_station: int = typer.Option(
START_STATION, help="Departure station number. (default: Berlin Hbf)"
),
end_station: int = typer.Option(
END_STATION, help="Destination station number. (default: Paris Est)"
),
travel_date: str = typer.Option(help="Travel day to search from. (YYYY-MM-DD)"),
):
try:
date_obj = datetime.strptime(travel_date, "%Y-%m-%d")
except ValueError:
typer.echo(f"Invalid date format: {travel_date}. Use YYYY-MM-DD", err=True)
raise typer.Exit(1)
main(start_station=start_station, end_station=end_station, travel_date=date_obj)
if __name__ == "__main__":