Source code for snews_cs.snews_sql

import inspect
import os
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd

from . import cs_utils
from .database import Database

# Get the directory of the script
[docs] current_script_directory = os.path.dirname( os.path.abspath(inspect.getfile(inspect.currentframe())) )
# Go back one level to the parent directory
[docs] parent_directory = os.path.join(current_script_directory, os.pardir)
[docs] class Storage: """ Class for interacting with the SNEWS SQL database. Parameters ---------- env : `str`, optional Path to env file, defaults to './etc/test-config.env' drop_db : `bool`, optional drops all items in the DB every time Storage is initialized, defaults to False """ def __init__(self, env=None, drop_db=True): cs_utils.set_env(env)
[docs] self.mgs_expiration = int(os.getenv("MSG_EXPIRATION"))
[docs] self.coinc_threshold = int(os.getenv("COINCIDENCE_THRESHOLD"))
[docs] self.db_path = os.path.join(parent_directory, "snews_cs.db")
[docs] self.db = Database(db_file_path=self.db_path)
[docs] self.conn = self.db.connection
[docs] self.cursor = self.db.cursor
if drop_db: self.db.drop_tables( table_names=[ "all_mgs", "sig_tier_archive", "time_tier_archive", "coincidence_tier_archive", "coincidence_tier_alerts", ] ) self.db.initialize_database(sql_schema_path=Path(__file__).parent / "db_schema.sql")
[docs] def insert_mgs(self, mgs, tier): """ Inserts a message into the all_mgs table. Parameters ---------- mgs : `dict` dictionary of the SNEWS message """ # to sent time datetime string and expiration datetime string expiration = datetime.fromisoformat(mgs["received_time"]) + timedelta(hours=48) expiration = expiration.isoformat() # MK: proposed change # expiration = np.datetime64(mgs['received_time'][0]) + np.timedelta64(48, 'h') # expiration = np.datetime_as_string(expiration, unit='ns') if tier == "SIG": self.cursor.execute( """INSERT INTO all_mgs VALUES (?, ?, ?, ?, ?)""", (mgs["id"], mgs["received_time"], "SIG", str(mgs), expiration), ) self.cursor.execute( """INSERT INTO sig_tier_archive VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( mgs["id"], mgs["schema_version"], mgs["detector_name"], str(mgs["p_vals"]), mgs["t_bin_width_sec"], mgs["sent_time_utc"], mgs["machine_time_utc"], str(mgs["meta"]), expiration, ), ) self.conn.commit() elif tier == "TIME": self.cursor.execute( """INSERT INTO all_mgs VALUES (?, ?, ?, ?, ?)""", (mgs["id"], mgs["received_time"], "TIME", str(mgs), expiration), ) self.cursor.execute( """INSERT INTO time_tier_archive VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( mgs["id"], mgs["schema_version"], mgs["detector_name"], mgs["p_val"], mgs["t_bin_width_sec"], str(mgs["timing_series"]), mgs["sent_time_utc"], mgs["machine_time_utc"], str(mgs["meta"]), expiration, ), ) self.conn.commit() elif tier == "COINC": self.cursor.execute( """INSERT INTO all_mgs VALUES (?, ?, ?, ?, ?)""", (mgs["id"], mgs["received_time"], "COINC", str(mgs), expiration), ) self.cursor.execute( """INSERT INTO coincidence_tier_archive VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( mgs["id"], mgs["schema_version"], mgs["detector_name"], mgs["p_val"], mgs["neutrino_time_utc"], mgs["sent_time_utc"], mgs["machine_time_utc"], str(mgs["meta"]), expiration, ), ) self.conn.commit()
[docs] def insert_alert(self, alert, tier): if tier == "COINC": self.cursor.execute( """INSERT INTO coincidence_tier_alerts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( alert["id"], alert["alert_type"], alert["server_tag"], alert["false_alarm_prob"], str(alert["detector_names"]), alert["sent_time_utc"], str(alert["p_vals"]), str(alert["neutrino_times"]), alert["p_vals_average"], alert["sub_list_number"], ), ) self.conn.commit() elif tier == "SIG": pass elif tier == "TIME": pass
[docs] def drop_expired(self): """ Drops all expired messages from the all_mgs table. """ self.cursor.execute( """DELETE FROM all_mgs WHERE expiration < ?""", (datetime.now().isoformat(),), ) self.cursor.execute( """DELETE FROM sig_tier_archive WHERE expiration < ?""", (datetime.now().isoformat(),), ) self.cursor.execute( """DELETE FROM time_tier_archive WHERE expiration < ?""", (datetime.now().isoformat(),), ) self.cursor.execute( """DELETE FROM coincidence_tier_archive WHERE expiration < ?""", (datetime.now().isoformat(),), ) self.conn.commit()
[docs] def get_all_messages(self, sort_order="ASC"): """ Returns all messages in the all_mgs table. """ self.cursor.execute( """SELECT * FROM all_mgs ORDER BY received_time {}""".format(sort_order) ) return self.cursor.fetchall()
[docs] def get_all_coinc_alerts(self, sort_order="ASC"): """ Returns all messages in the all_mgs table. """ self.cursor.execute( """SELECT * FROM coincidence_tier_alerts ORDER BY sent_time_utc {}""".format( sort_order ) ) return self.cursor.fetchall()
[docs] def get_all_sig_alerts(self, sort_order="ASC"): pass
[docs] def get_all_time_alerts(self, sort_order="ASC"): pass
[docs] def get_all_sig_messages(self, sort_order="ASC"): """ Returns all messages in the all_mgs table. """ self.cursor.execute( """SELECT * FROM sig_tier_archive ORDER BY sent_time_utc {}""".format( sort_order ) ) table = self.cursor.fetchall() return table
[docs] def get_all_time_messages(self, sort_order="ASC"): """ Returns all messages in the all_mgs table. """ self.cursor.execute( """SELECT * FROM time_tier_archive ORDER BY sent_time_utc {}""".format( sort_order ) ) table = self.cursor.fetchall() return table
[docs] def get_all_coinc_messages(self, sort_order="ASC"): """ Returns all messages in the all_mgs table. """ self.cursor.execute( """SELECT * FROM coincidence_tier_archive ORDER BY sent_time_utc {}""".format( sort_order ) ) table = self.cursor.fetchall() return table
[docs] def retract_message(self, message_id, tier): """ Retracts a message from the all_mgs table. Parameters ---------- message_id : `str` unique id for each message """ self.cursor.execute( """DELETE FROM all_mgs WHERE message_id = ?""", (message_id,) ) if tier == "SIG": self.cursor.execute( """DELETE FROM sig_tier_archive WHERE message_id = ?""", (message_id,) ) elif tier == "TIME": self.cursor.execute( """DELETE FROM time_tier_archive WHERE message_id = ?""", (message_id,) ) elif tier == "COINC": self.cursor.execute( """DELETE FROM coincidence_tier_archive WHERE message_id = ?""", (message_id,), ) self.conn.commit()
[docs] def update_message(self, message, tier): """ Updates a message in the all_mgs table and corresponding tier tabe. """ self.cursor.execute( """UPDATE all_mgs SET message = ? WHERE message_id = ?""", (str(message), message["id"]), ) if tier == "SIG": # update all columns except _id self.cursor.execute( """UPDATE sig_tier_archive SET schema_version = ?, detector_name = ?, p_vals = ?, t_bin_width_sec = ?, sent_time_utc = ?, machine_time_utc = ?, meta = ? WHERE message_id = ?""", ( message["schema_version"], message["detector_name"], str(message["p_vals"]), message["t_bin_width_sec"], message["sent_time_utc"], message["machine_time_utc"], str(message["meta"]), message["id"], ), ) elif tier == "TIME": # update all columns except _id self.cursor.execute( """UPDATE time_tier_archive SET schema_version = ?, detector_name = ?, p_val = ?, t_bin_width_sec = ?, timing_series = ?, sent_time_utc = ?, machine_time_utc = ?, meta = ? WHERE message_id = ?""", ( message["schema_version"], message["detector_name"], message["p_val"], message["t_bin_width_sec"], str(message["timing_series"]), message["sent_time_utc"], message["machine_time_utc"], str(message["meta"]), message["id"], ), ) elif tier == "COINC": # update all columns except _id self.cursor.execute( """UPDATE coincidence_tier_archive SET schema_version = ?, detector_name = ?, p_val = ?, neutrino_time_utc = ?, sent_time_utc = ?, machine_time_utc = ?, meta = ? WHERE message_id = ?""", ( message["schema_version"], message["detector_name"], message["p_val"], message["neutrino_time_utc"], message["sent_time_utc"], message["machine_time_utc"], str(message["meta"]), message["id"], ), ) self.conn.commit()
[docs] def show_tables(self): """ Returns all tables in the SQL database. """ self.cursor.execute( """SELECT name FROM sqlite_master WHERE type='table' ORDER BY name""" ) table = self.cursor.fetchall() return table
[docs] def get_table_schema(self, table_name): """ Returns the schema for a given table. """ self.cursor.execute("""PRAGMA table_info({})""".format(table_name)) schema = self.cursor.fetchall() return schema
[docs] def insert_coinc_cache(self, cache): """ Inserts coincidence cache dataframe into the coincidence_tier_archive table. Parameters ---------- cache : dataframe dictionary of the SNEWS message """ # to sent time datetime string and expiration datetime string expiration = np.datetime64(cache["sent_time_utc"][0]) + np.timedelta64(48, "h") expiration = np.datetime_as_string(expiration, unit="ns") # expiration = datetime.fromisoformat(cache['sent_time'][0]) + timedelta(hours=48) # expiration = expiration.isoformat() # coincidence_tier_archive is not empty delete all rows self.cursor.execute("""DELETE FROM coincidence_tier_archive""") # insert dataframe into table insert_query = """ INSERT INTO coincidence_tier_archive ( message_id, schema_version, detector_name, p_val, neutrino_time_utc, sent_time_utc, machine_time_utc, meta, expiration ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """ try: for index, row in cache.iterrows(): self.cursor.execute( insert_query, ( row["id"], row["schema_version"], row["detector_name"], row["p_val"], row["neutrino_time_utc"], row["sent_time_utc"], row["machine_time_utc"], str(row["meta"]), expiration, ), ) self.conn.commit() except Exception as e: # Log the error and rollback the transaction if needed print(f"Error inserting data: {e}") self.conn.rollback()
[docs] def retrieve_coinc_cache(self): """ Returns coincidence cache dataframe from the coincidence_tier_archive table and saves it as a dataframe. """ self.cursor.execute("""SELECT * FROM coincidence_tier_archive""") table = self.cursor.fetchall() return pd.DataFrame( table, columns=[ "message_id", "schema_version", "detector_name", "p_val", "neutrino_time_utc", "sent_time_utc", "machine_time_utc", "meta", "expiration", ], )