Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 67 additions & 36 deletions rehline/_mf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import _check_sample_weight
Expand Down Expand Up @@ -199,42 +198,22 @@ def __init__(
tol_CD=1e-4,
verbose=0,
):
# check input
errors = []
checks = [
(0 < rho < 1, "rho must be between 0 and 1"),
(C > 0, "C must be positive"),
(tol_CD > 0, "tol_CD must be positive"),
(tol > 0, "tol must be positive"),
]
for condition, error_msg in checks:
if not condition:
errors.append(error_msg)
if errors:
raise ValueError("; ".join(errors))

# parameter initialization
## -----------------------------basic perameters-----------------------------
## -----------------------------basic parameters-----------------------------
self.n_users = n_users
self.n_items = n_items
self.loss = loss
self.constraint_user = constraint_user if constraint_user is not None else []
self.constraint_item = constraint_item if constraint_item is not None else []
self.biased = biased
## -----------------------------hyper perameters-----------------------------
## -----------------------------hyper parameters-----------------------------
self.rank = rank
self.C = C
self.rho = rho
## --------------------------coefficient perameters--------------------------
## -------------------------initialization parameters------------------------
self.init_mean = init_mean
self.init_sd = init_sd
self.random_state = random_state
if self.random_state:
np.random.seed(random_state)
self.P = np.random.normal(loc=init_mean, scale=init_sd, size=(n_users, rank))
self.Q = np.random.normal(loc=init_mean, scale=init_sd, size=(n_items, rank))
self.bu = np.zeros(n_users) if self.biased else None
self.bi = np.zeros(n_items) if self.biased else None
## ----------------------------fitting parameters----------------------------
self.max_iter_CD = max_iter_CD
self.tol_CD = tol_CD
Expand Down Expand Up @@ -266,17 +245,62 @@ def fit(self, X, y, sample_weight=None):
An instance of the estimator.

"""
# check input
## parameter validation
errors = []
checks = [
(0 < self.rho < 1, "rho must be between 0 and 1"),
(self.C > 0, "C must be positive"),
(self.tol_CD > 0, "tol_CD must be positive"),
(self.tol > 0, "tol must be positive"),
]
for condition, error_msg in checks:
if not condition:
errors.append(error_msg)
if errors:
raise ValueError("; ".join(errors))

## data validation
X = np.asarray(X)
y = np.asarray(y)
if X.ndim != 2 or X.shape[1] != 2:
raise ValueError("X must have shape (n_ratings, 2)")
if X.shape[0] != len(y):
raise ValueError("X and y must have the same number of samples")
user_ids = X[:, 0].astype(int)
item_ids = X[:, 1].astype(int)
if np.any(user_ids < 0) or np.any(user_ids >= self.n_users):
raise ValueError("User IDs must be in [0, n_users)")
if np.any(item_ids < 0) or np.any(item_ids >= self.n_items):
raise ValueError("Item IDs must be in [0, n_items)")

# Preparation
self.n_ratings = len(y)
self.history = np.nan * np.zeros((self.max_iter_CD + 1, 2))
## number of training observations
self.n_ratings = len(y)
## convergence trace
self.history = np.full((self.max_iter_CD + 1, 2), np.nan)
## sample weights
self.sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

X_df = pd.DataFrame(X, columns=["user", "item"])
uidx_map = X_df.groupby("user").indices
iidx_map = X_df.groupby("item").indices
self.Iu = [uidx_map.get(u, np.array([], dtype=int)) for u in range(self.n_users)]
self.Ui = [iidx_map.get(i, np.array([], dtype=int)) for i in range(self.n_items)]

## random number generator
rng = np.random.default_rng(self.random_state)

## indices to locate interactions given a user or item id
### user side: Iu[u] = row indices of interactions by user u
sort_idx_users = np.argsort(X[:, 0], kind='stable')
sorted_users = X[sort_idx_users, 0]
counts = np.unique(sorted_users, return_counts=True)[1]
self.Iu = [np.array([], dtype=int) for _ in range(self.n_users)]
for u, idxs in zip(sorted_users[np.cumsum(counts) - counts], np.split(sort_idx_users, np.cumsum(counts)[:-1])):
self.Iu[u] = idxs
### item side: Ui[i] = row indices of interactions that involve item i
sort_idx_items = np.argsort(X[:, 1], kind='stable')
sorted_items = X[sort_idx_items, 1]
counts = np.unique(sorted_items, return_counts=True)[1]
self.Ui = [np.array([], dtype=int) for _ in range(self.n_items)]
for i, idxs in zip(sorted_items[np.cumsum(counts) - counts], np.split(sort_idx_items, np.cumsum(counts)[:-1])):
self.Ui[i] = idxs

## effective C when updating user/item blocks (to match rehline formulation: C * PLQ_loss + 0.5 * l_2)
C_user = self.C * self.n_users / (self.rho) / 2
C_item = self.C * self.n_items / (1 - self.rho) / 2

Expand All @@ -289,6 +313,12 @@ def fit(self, X, y, sample_weight=None):
)
)

# Model Initialization
self.P = rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_users, self.rank))
self.Q = rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_items, self.rank))
self.bu = np.zeros(self.n_users) if self.biased else None
self.bi = np.zeros(self.n_items) if self.biased else None

# CD algorithm
self.history[0] = self.obj(X, y)
for iter_idx in range(self.max_iter_CD):
Expand Down Expand Up @@ -435,7 +465,7 @@ def fit(self, X, y, sample_weight=None):
obj = f"{self.history[iter_idx + 1][1]:.6f}"
print(f"{iter_idx + 1:<12} {mean_loss:<20} {obj:<20}")

if obj_diff < self.tol_CD:
if abs(obj_diff) < self.tol_CD:
break

return self
Expand Down Expand Up @@ -496,9 +526,10 @@ def obj(self, X, y):
item_penalty = np.sum(self.Q**2) * (1 - self.rho) / self.n_items
penalty = user_penalty + item_penalty

y_pred = self.decision_function(X)
U, V, Tau, S, T = _make_loss_rehline_param(loss=self.loss, X=X, y=y)
X_dummy = np.ones((len(y), 1)) # not used in loss computation, only shape matters for loss param construction
U, V, Tau, S, T = _make_loss_rehline_param(loss=self.loss, X=X_dummy, y=y)
loss = ReHLoss(U, V, S, T, Tau)
y_pred = self.decision_function(X)
loss_term = loss(y_pred)

return loss_term, self.C * loss_term + penalty
Loading