#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import pickle
from collections.abc import Iterable
from typing import Any, Dict
from aepsych.config import Config
from aepsych.version import __version__
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Integer,
PickleType,
String,
)
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
logger = logging.getLogger()
Base = declarative_base()
[docs]class DBMasterTable(Base):
"""
Master table to keep track of all experiments and unique keys associated with the experiment
"""
__tablename__ = "master"
unique_id = Column(Integer, primary_key=True, autoincrement=True)
experiment_name = Column(String(256), nullable=True)
experiment_description = Column(String(2048), nullable=True)
experiment_id = Column(String(10))
participant_id = Column(String(50))
extra_metadata = Column(String(4096)) # JSON-formatted metadata
children_replay = relationship("DbReplayTable", back_populates="parent")
children_strat = relationship("DbStratTable", back_populates="parent")
children_config = relationship("DbConfigTable", back_populates="parent")
children_raw = relationship("DbRawTable", back_populates="parent")
[docs] @classmethod
def from_sqlite(cls, row: Dict[str, Any]) -> "DBMasterTable":
"""Create a DBMasterTable object from a row in the sqlite database.
Args:
row (Dict[str, Any]): A row from the sqlite database.
Returns:
DBMasterTable: A DBMasterTable object.
"""
this = DBMasterTable()
this.unique_id = row["unique_id"]
this.experiment_name = row["experiment_name"]
this.experiment_description = row["experiment_description"]
this.experiment_id = row["experiment_id"]
return this
def __repr__(self) -> str:
"""Return a string representation of the DBMasterTable object.
Returns:
str: A string representation of the DBMasterTable object.
"""
return (
f"<DBMasterTable(unique_id={self.unique_id})"
f", experiment_name={self.experiment_name}, "
f"experiment_description={self.experiment_description}, "
f"experiment_id={self.experiment_id})>"
)
[docs] @staticmethod
def update(engine: Engine) -> None:
"""Update the master table schema to include extra_metadata and participant_id columns.
Args:
engine (Engine): The sqlalchemy engine.
"""
logger.info("DBMasterTable : update called")
if not DBMasterTable._has_column(engine, "extra_metadata"):
DBMasterTable._add_column(engine, "extra_metadata")
if not DBMasterTable._has_column(engine, "participant_id"):
DBMasterTable._add_column(engine, "participant_id")
[docs] @staticmethod
def requires_update(engine: Engine) -> bool:
"""Check if the master table schema requires an update.
Args:
engine (Engine): The sqlalchemy engine.
"""
return not DBMasterTable._has_column(
engine, "extra_metadata"
) or not DBMasterTable._has_column(engine, "participant_id")
@staticmethod
def _has_column(engine: Engine, column: str) -> bool:
"""Check if the master table has a column.
Args:
engine (Engine): The sqlalchemy engine.
column (str): The column name.
Returns:
bool: True if the column exists, False otherwise.
"""
result = engine.execute(
"SELECT COUNT(*) FROM pragma_table_info('master') WHERE name='{0}'".format(
column
)
)
rows = result.fetchall()
count = rows[0][0]
return count != 0
@staticmethod
def _add_column(engine: Engine, column: str) -> None:
"""Add a column to the master table.
Args:
engine (Engine): The sqlalchemy engine.
column (str): The column name.
"""
try:
result = engine.execute(
"SELECT COUNT(*) FROM pragma_table_info('master') WHERE name='{0}'".format(
column
)
)
rows = result.fetchall()
count = rows[0][0]
if 0 == count:
logger.debug(
"Altering the master table to add the {0} column".format(column)
)
engine.execute(
"ALTER TABLE master ADD COLUMN {0} VARCHAR".format(column)
)
engine.commit()
except Exception as e:
logger.debug(f"Column already exists, no need to alter. [{e}]")
@staticmethod
def _update_column(engine: Engine, column: str, spec: str) -> None:
"""Update column with a new spec.
Args:
engine (Engine): The sqlalchemy engine.
column (str): The column name.
spec (str): The new column spec.
"""
logger.debug(f"Altering the master table column: {column} to this spec {spec}")
engine.execute(f"ALTER TABLE master MODIFY {column} {spec}")
engine.commit()
[docs]class DbReplayTable(Base):
__tablename__ = "replay_data"
use_extra_info = False
unique_id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(DateTime)
message_type = Column(String(64))
# specify the pickler to allow backwards compatibility between 3.7 and 3.8
message_contents = Column(PickleType(pickler=pickle))
extra_info = Column(PickleType(pickler=pickle))
master_table_id = Column(Integer, ForeignKey("master.unique_id"))
parent = relationship("DBMasterTable", back_populates="children_replay")
__mapper_args__ = {}
[docs] @classmethod
def from_sqlite(cls, row: Dict[str, Any]) -> "DbReplayTable":
"""Create a DbReplayTable object from a row in the sqlite database.
Args:
row (Dict[str, Any]): A row from the sqlite database.
Returns:
DbReplayTable: A DbReplayTable object.
"""
this = DbReplayTable()
this.unique_id = row["unique_id"]
this.timestamp = row["timestamp"]
this.message_type = row["message_type"]
this.message_contents = row["message_contents"]
this.master_table_id = row["master_table_id"]
if "extra_info" in row:
this.extra_info = row["extra_info"]
else:
this.extra_info = None
this.strat = row["strat"]
return this
def __repr__(self) -> str:
"""Return a string representation of the DbReplayTable object.
Returns:
str: A string representation of the DbReplayTable object.
"""
return (
f"<DbReplayTable(unique_id={self.unique_id})"
f", timestamp={self.timestamp}, "
f"message_type={self.message_type}"
f", master_table_id={self.master_table_id})>"
)
@staticmethod
def _has_extra_info(engine: Engine) -> bool:
"""Check if the replay_data table has an extra_info column.
Args:
engine (Engine): The sqlalchemy engine.
Returns:
bool: True if the extra_info column exists, False otherwise.
"""
result = engine.execute(
"SELECT COUNT(*) FROM pragma_table_info('replay_data') WHERE name='extra_info'"
)
rows = result.fetchall()
count = rows[0][0]
return count != 0
@staticmethod
def _configs_require_conversion(engine: Engine) -> bool:
"""Check if the replay_data table has any old configs that need to be converted.
Args:
engine (Engine): The sqlalchemy engine.
Returns:
bool: True if any old configs need to be converted, False otherwise.
"""
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()
results = session.query(DbReplayTable).all()
for result in results:
if result.message_contents["type"] == "setup":
config_str = result.message_contents["message"]["config_str"]
config = Config(config_str=config_str)
if config.version < __version__:
return True # assume that if any config needs to be refactored, all of them do
return False
[docs] @staticmethod
def update(engine: Engine) -> None:
"""Update the replay_data table schema to include an extra_info column and convert old configs.
Args:
engine (Engine): The sqlalchemy engine.
"""
logger.info("DbReplayTable : update called")
if not DbReplayTable._has_extra_info(engine):
DbReplayTable._add_extra_info(engine)
if DbReplayTable._configs_require_conversion(engine):
DbReplayTable._convert_configs(engine)
[docs] @staticmethod
def requires_update(engine: Engine) -> bool:
"""Check if the replay_data table schema requires an update.
Args:
engine (Engine): The sqlalchemy engine.
"""
return not DbReplayTable._has_extra_info(
engine
) or DbReplayTable._configs_require_conversion(engine)
@staticmethod
def _add_extra_info(engine: Engine) -> None:
"""Add an extra_info column to the replay_data table.
Args:
engine (Engine): The sqlalchemy engine.
"""
try:
result = engine.execute(
"SELECT COUNT(*) FROM pragma_table_info('replay_data') WHERE name='extra_info'"
)
rows = result.fetchall()
count = rows[0][0]
if 0 == count:
logger.debug(
"Altering the replay_data table to add the extra_info column"
)
engine.execute("ALTER TABLE replay_data ADD COLUMN extra_info BLOB")
engine.commit()
except Exception as e:
logger.debug(f"Column already exists, no need to alter. [{e}]")
@staticmethod
def _convert_configs(engine: Engine) -> None:
"""Convert old configs to the latest version.
Args:
engine (Engine): The sqlalchemy engine.
"""
Session = sessionmaker(bind=engine)
session = Session()
results = session.query(DbReplayTable).all()
for result in results:
if result.message_contents["type"] == "setup":
config_str = result.message_contents["message"]["config_str"]
config = Config(config_str=config_str)
if config.version < __version__:
config.convert_to_latest()
new_str = str(config)
new_message = {"type": "setup", "message": {"config_str": new_str}}
if "version" in result.message_contents:
new_message["version"] = result.message_contents["version"]
result.message_contents = new_message
session.commit()
logger.info("DbReplayTable : updated old configs.")
[docs]class DbStratTable(Base):
__tablename__ = "strat_data"
unique_id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(DateTime)
strat = Column(PickleType(pickler=pickle))
master_table_id = Column(Integer, ForeignKey("master.unique_id"))
parent = relationship("DBMasterTable", back_populates="children_strat")
[docs] @classmethod
def from_sqlite(cls, row: Dict[str, Any]) -> "DbStratTable":
"""Create a DbStratTable object from a row in the sqlite database.
Args:
row (Dict[str, Any]): A row from the sqlite database.
Returns:
DbStratTable: A DbStratTable object.
"""
this = DbStratTable()
this.unique_id = row["unique_id"]
this.timestamp = row["timestamp"]
this.strat = row["strat"]
this.master_table_id = row["master_table_id"]
return this
def __repr__(self) -> str:
"""Return a string representation of the DbStratTable object.
Returns:
str: A string representation of the DbStratTable object.
"""
return (
f"<DbStratTable(unique_id={self.unique_id})"
f", timestamp={self.timestamp} "
f", master_table_id={self.master_table_id})>"
)
[docs] @staticmethod
def update(engine: Engine) -> None:
"""Update the strat_data table schema.
Args:
engine (Engine): The sqlalchemy engine.
"""
logger.info("DbStratTable : update called")
[docs] @staticmethod
def requires_update(engine: Engine) -> bool:
"""Check if the strat_data table schema requires an update.(always False)
Args:
engine (Engine): The sqlalchemy engine.
"""
return False
[docs]class DbConfigTable(Base):
__tablename__ = "config_data"
unique_id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(DateTime)
config = Column(PickleType(pickler=pickle))
master_table_id = Column(Integer, ForeignKey("master.unique_id"))
parent = relationship("DBMasterTable", back_populates="children_config")
[docs] @classmethod
def from_sqlite(cls, row: Dict[str, Any]) -> "DbConfigTable":
"""Create a DbConfigTable object from a row in the sqlite database.
Args:
row (Dict[str, Any]): A row from the sqlite database.
Returns:
DbConfigTable: A DbConfigTable object.
"""
this = DbConfigTable()
this.unique_id = row["unique_id"]
this.timestamp = row["timestamp"]
this.strat = row["config"]
this.master_table_id = row["master_table_id"]
return this
def __repr__(self) -> str:
"""Return a string representation of the DbConfigTable object.
Returns:
str: A string representation of the DbConfigTable object.
"""
return (
f"<DbStratTable(unique_id={self.unique_id})"
f", timestamp={self.timestamp} "
f", master_table_id={self.master_table_id})>"
)
[docs] @staticmethod
def update(engine: Engine) -> None:
"""Update the config_data table schema.
Args:
engine (Engine): The sqlalchemy engine.
"""
logger.info("DbConfigTable : update called")
[docs] @staticmethod
def requires_update(engine: Engine) -> bool:
"""Check if the config_data table schema requires an update.(always False)
Args:
engine (Engine): The sqlalchemy engine.
"""
return False
[docs]class DbRawTable(Base):
"""
Fact table to store the raw data of each iteration of an experiment.
"""
__tablename__ = "raw_data"
unique_id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(DateTime)
model_data = Column(Boolean)
master_table_id = Column(Integer, ForeignKey("master.unique_id"))
parent = relationship("DBMasterTable", back_populates="children_raw")
children_param = relationship("DbParamTable", back_populates="parent")
children_outcome = relationship("DbOutcomeTable", back_populates="parent")
[docs] @classmethod
def from_sqlite(cls, row: Dict[str, Any]) -> "DbRawTable":
"""Create a DbRawTable object from a row in the sqlite database.
Args:
row (Dict[str, Any]): A row from the sqlite database.
Returns:
DbRawTable: A DbRawTable object.
"""
this = DbRawTable()
this.unique_id = row["unique_id"]
this.timestamp = row["timestamp"]
this.model_data = row["model_data"]
this.master_table_id = row["master_table_id"]
return this
def __repr__(self) -> str:
"""Return a string representation of the DbRawTable object.
Returns:
str: A string representation of the DbRawTable object.
"""
return (
f"<DbRawTable(unique_id={self.unique_id})"
f", timestamp={self.timestamp} "
f", master_table_id={self.master_table_id})>"
)
[docs] @staticmethod
def update(db: Any, engine: Engine) -> None:
"""Update the raw table with data from the replay table.
Args:
db (Any): The database object.
engine (Engine): The sqlalchemy engine.
"""
logger.info("DbRawTable : update called")
# Get every master table
for master_table in db.get_master_records():
# Get raw tab
for message in master_table.children_replay:
if message.message_type != "tell":
continue
timestamp = message.timestamp
# Deserialize pickle message
message_contents = message.message_contents
# Get outcome
outcomes = message_contents["message"]["outcome"]
# Get parameters
params = message_contents["message"]["config"]
# Get model_data
model_data = message_contents["message"].get("model_data", True)
db_raw_record = db.record_raw(
master_table=master_table,
model_data=bool(model_data),
timestamp=timestamp,
)
for param_name, param_value in params.items():
if isinstance(param_value, Iterable) and type(param_value) != str:
if len(param_value) == 1:
db.record_param(
raw_table=db_raw_record,
param_name=str(param_name),
param_value=float(param_value[0]),
)
else:
for j, v in enumerate(param_value):
db.record_param(
raw_table=db_raw_record,
param_name=str(param_name) + "_stimuli" + str(j),
param_value=float(v),
)
else:
db.record_param(
raw_table=db_raw_record,
param_name=str(param_name),
param_value=float(param_value),
)
if isinstance(outcomes, Iterable) and type(outcomes) != str:
for j, outcome_value in enumerate(outcomes):
if (
isinstance(outcome_value, Iterable)
and type(outcome_value) != str
):
if len(outcome_value) == 1:
outcome_value = outcome_value[0]
else:
raise ValueError(
"Multi-outcome values must be a list of lists of length 1!"
)
db.record_outcome(
raw_table=db_raw_record,
outcome_name="outcome_" + str(j),
outcome_value=float(outcome_value),
)
else:
db.record_outcome(
raw_table=db_raw_record,
outcome_name="outcome",
outcome_value=float(outcomes),
)
[docs] @staticmethod
def requires_update(engine: Engine) -> bool:
"""Check if the raw table is empty, and data already exists.
Args:
engine (Engine): The sqlalchemy engine.
Returns:
bool: True if the raw table is empty and data already exists, False otherwise.
"""
n_raws = engine.execute("SELECT COUNT (*) FROM raw_data").fetchone()[0]
n_tells = engine.execute(
"SELECT COUNT (*) FROM replay_data \
WHERE message_type = 'tell'"
).fetchone()[0]
if n_raws == 0 and n_tells != 0:
return True
return False
[docs]class DbParamTable(Base):
"""
Dimension table to store the parameters of each iteration of an experiment.
Supports multiple parameters per iteration, and multiple stimuli per parameter.
"""
__tablename__ = "param_data"
unique_id = Column(Integer, primary_key=True, autoincrement=True)
param_name = Column(String(50))
param_value = Column(String(50))
iteration_id = Column(Integer, ForeignKey("raw_data.unique_id"))
parent = relationship("DbRawTable", back_populates="children_param")
[docs] @classmethod
def from_sqlite(cls, row: Dict[str, Any]) -> "DbParamTable":
"""Create a DbParamTable object from a row in the sqlite database.
Args:
row (Dict[str, Any]): A row from the sqlite database.
Returns:
DbParamTable: A DbParamTable object.
"""
this = DbParamTable()
this.unique_id = row["unique_id"]
this.param_name = row["param_name"]
this.param_value = row["param_value"]
this.iteration_id = row["iteration_id"]
return this
def __repr__(self) -> str:
"""Return a string representation of the DbParamTable object.
Returns:
str: A string representation of the DbParamTable object.
"""
return (
f"<DbParamTable(unique_id={self.unique_id})"
f", iteration_id={self.iteration_id}>"
)
[docs] @staticmethod
def update(engine: Engine) -> None:
"""Update the param_data table schema.
Args:
engine (Engine): The sqlalchemy engine.
"""
logger.info("DbParamTable : update called")
[docs] @staticmethod
def requires_update(engine: Engine) -> bool:
"""Check if the param_data table schema requires an update.(always False)
Args:
engine (Engine): The sqlalchemy engine.
Returns:
bool: True if the param_data table schema requires an update, False otherwise.
"""
return False