在这篇文章中,我们将深入探讨一种有趣的深度学习(DL)新方法,称为关系深度学习(RDL)。我们还将通过在电子商务公司的实际数据库(不是数据集)上进行一些 RDL 来获得一些实践经验。
介绍
在现实世界中,我们通常有一个关系数据库,希望对其执行某些机器学习任务。但是,特别是当数据库高度规范化时,这意味着需要进行大量耗时的特征工程,并且在必须进行许多聚合时,会损失粒度。更重要的是,存在无数可能的特征组合,每一种都可能产生良好的性能。这意味着我们可能会遗漏一些与机器学习任务相关的信息。
这与深度学习神经网络出现之前的计算机视觉早期类似,当时特征是从像素值中手工提取的。如今,模型直接处理原始像素,而不再依赖这一中间层。
关系深度学习
RDL为表格学习做同样的事情。也就是说,它去除了构建特征矩阵的额外步骤,通过直接在关系数据库上学习来实现。通过将数据库及其关系转换为图来实现这一点,其中表中的一行变成节点,表之间的关系变成边。行值作为节点特征存储在节点内部。
在这篇文章中,我们将使用来自kaggle的这个电子商务数据集,它包含了一个电子商务平台交易数据,这些数据采用星型模式,具有一个中心事实表(transactions)和一些维度表。
在这篇文章中,我们将使用relbench库来进行RDL。在relbench中,我们要做的第一件事是指定我们关系数据库的架构。以下是我们如何为数据库中的“transactions”表执行此操作的示例。我们给出一个pandas数据框作为表,并指定主键和时间戳列。主键列用于唯一标识实体。时间戳确保我们只能从过去的交易中学习,当我们想要预测未来的交易时。在图中,这意味着信息只能从具有较低时间戳(即在过去)的节点流向具有较高时间戳的节点。此外,我们指定了关系中存在的外键。在这种情况下,“transactions”表有一个“customer_key”列,这是一个外键,指向“customer_dim”表。
tables['transactions'] = Table(
df=pd.DataFrame(t),
pkey_col='t_id',
fkey_col_to_pkey_table={
'customer_key': 'customers',
'item_key': 'products',
'store_key': 'stores'
},
time_col='date'
)
其他表也需要以同样的方式定义。请注意,如果你已经有一个数据库架构,这个过程也可以自动化。由于数据集来自Kaggle,我需要手动创建架构。我们还需要将日期列转换为实际的pandas datetime对象,并移除任何NaN值。
class EcommerceDataBase(Dataset):
# example of creating your own dataset: https://github.com/snap-stanford/relbench/blob/main/tutorials/custom_dataset.ipynb
val_timestamp = pd.Timestamp(year=2018, month=1, day=1)
test_timestamp = pd.Timestamp(year=2020, month=1, day=1)
def make_db(self) -> Database:
tables = {}
customers = load_csv_to_db(BASE_DIR + '/customer_dim.csv').drop(columns=['contact_no', 'nid']).rename(columns={'coustomer_key': 'customer_key'})
stores = load_csv_to_db(BASE_DIR + '/store_dim.csv').drop(columns=['upazila'])
products = load_csv_to_db(BASE_DIR + '/item_dim.csv')
transactions = load_csv_to_db(BASE_DIR + '/fact_table.csv').rename(columns={'coustomer_key': 'customer_key'})
times = load_csv_to_db(BASE_DIR + '/time_dim.csv')
t = transactions.merge(times[['time_key', 'date']], on='time_key').drop(columns=['payment_key', 'time_key', 'unit'])
t['date'] = pd.to_datetime(t.date)
t = t.reset_index().rename(columns={'index': 't_id'})
t['quantity'] = t.quantity.astype(int)
t['unit_price'] = t.unit_price.astype(float)
products['unit_price'] = products.unit_price.astype(float)
t['total_price'] = t.total_price.astype(float)
print(t.isna().sum(axis=0))
print(products.isna().sum(axis=0))
print(stores.isna().sum(axis=0))
print(customers.isna().sum(axis=0))
tables['products'] = Table(
df=pd.DataFrame(products),
pkey_col='item_key',
fkey_col_to_pkey_table={},
time_col=None
)
tables['customers'] = Table(
df=pd.DataFrame(customers),
pkey_col='customer_key',
fkey_col_to_pkey_table={},
time_col=None
)
tables['transactions'] = Table(
df=pd.DataFrame(t),
pkey_col='t_id',
fkey_col_to_pkey_table={
'customer_key': 'customers',
'item_key': 'products',
'store_key': 'stores'
},
time_col='date'
)
tables['stores'] = Table(
df=pd.DataFrame(stores),
pkey_col='store_key',
fkey_col_to_pkey_table={}
)
return Database(tables)
至关重要的是,作者引入了训练表的概念。这个训练表本质上定义了机器学习任务。这里的想法是,我们想要预测数据库中某个实体的未来状态(即未来值)。我们通过指定一个表来实现这一点,其中每行都有一个时间戳、实体的标识符以及我们想要预测的一些值。标识符用于指定实体,时间戳指定我们需要预测实体的哪个时间点。这也会限制可用于推断该实体值的数据(即仅限过去的数据)。值本身就是我们想要预测的内容(即真实值)。
在我们的案例中,我们有一个拥有客户的在线平台。我们想要预测客户在未来30天内的收入。我们可以使用DuckDB执行的SQL语句来创建训练表。这是RDL的巨大优势,因为我们只需要使用SQL就可以创建任何类型的机器学习任务。例如,我们可以定义一个查询来选择买家在未来30天内的购买数量,以进行流失预测。
df = duckdb.sql(f"""
select
timestamp,
customer_key,
sum(total_price) as revenue
from
timestamp_df t
left join
transactions ta
on
ta.date <= t.timestamp + INTERVAL '{self.timedelta}'
and ta.date > t.timestamp
group by timestamp, customer_key
""").df().dropna()
结果将是一个表格,其中seller_id作为我们想要预测的实体的键,revenue作为目标值,timestamp作为我们需要进行预测的时间点(即我们只能使用到这个时间点的数据来进行预测)。
以下是创建“customer_revenue”任务的完整代码。
class CustomerRevenueTask(EntityTask):
# example of custom task: https://github.com/snap-stanford/relbench/blob/main/tutorials/custom_task.ipynb
task_type = TaskType.REGRESSION
entity_col = "customer_key"
entity_table = "customers"
time_col = "timestamp"
target_col = "revenue"
timedelta = pd.Timedelta(days=30) # how far we want to predict revenue into the future.
metrics = [r2, mae]
num_eval_timestamps = 40
def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
timestamp_df = pd.DataFrame({"timestamp": timestamps})
transactions = db.table_dict["transactions"].df
df = duckdb.sql(f"""
select
timestamp,
customer_key,
sum(total_price) as revenue
from
timestamp_df t
left join
transactions ta
on
ta.date <= t.timestamp + INTERVAL '{self.timedelta}'
and ta.date > t.timestamp
group by timestamp, customer_key
""").df().dropna()
print(df)
return Table(
df=df,
fkey_col_to_pkey_table={self.entity_col: self.entity_table},
pkey_col=None,
time_col=self.time_col,
)
至此,我们已经完成了大部分工作。剩余的工作流程将是类似的,与具体的机器学习任务无关。
例如,我们需要对节点特征进行编码。在这里,我们可以使用glove嵌入来编码所有的文本特征,如产品描述和产品名称。
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor
class GloveTextEmbedding:
def __init__(self, device: Optional[torch.device
] = None):
self.model = SentenceTransformer(
"sentence-transformers/average_word_embeddings_glove.6B.300d",
device=device,
)
def __call__(self, sentences: List[str]) -> Tensor:
return torch.from_numpy(self.model.encode(sentences))
之后,我们可以将这些转换应用到我们的数据中,并构建出图结构。
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph
text_embedder_cfg = TextEmbedderConfig(
text_embedder=GloveTextEmbedding(device=device), batch_size=256
)
data, col_stats_dict = make_pkey_fkey_graph(
db,
col_to_stype_dict=col_to_stype_dict, # speficied column types
text_embedder_cfg=text_embedder_cfg, # our chosen text encoder
cache_dir=os.path.join(
root_dir, f"rel-ecomm_materialized_cache"
), # store materialized graph for convenience
)
余下的代码将包括使用标准层构建图神经网络(GNN)、编写训练循环以及进行一些评估。为了简洁起见,我将不在这篇文章中列出这些代码,因为它们非常标准,并且在不同的任务中都是相同的。
因此,我们可以训练这个图神经网络(GNN),使其r2达到约0.3,平均绝对误差(MAE)达到500。这意味着它预测卖家未来30天的收入,平均误差为±500美元。当然,我们无法判断这是好是坏,也许通过传统的机器学习和特征工程的结合,我们可以获得80%的r2。
结论
关系深度学习(Relational Deep Learning)是一种有趣的机器学习新方法,特别是当我们拥有一个复杂的关系架构,而手动特征工程又过于繁琐时。它使我们能够仅使用SQL来定义机器学习任务,这对于那些不深入研究数据科学但了解一些SQL的人来说尤其有用。这也意味着我们可以快速迭代,并尝试许多不同的任务。
同时,这种方法也带来了自己的问题,如训练GNN的难度以及从关系架构中构建图的复杂性。此外,还有一个问题是关系深度学习在性能上能在多大程度上与传统机器学习模型竞争。过去,我们已经看到像XGboost这样的模型在表格预测问题上证明比神经网络更好。