Skip to content

Instantly share code, notes, and snippets.

@louislung
Last active January 29, 2024 10:37
Show Gist options
  • Select an option

  • Save louislung/cee96b28783056fc8d26a078b12e9ae3 to your computer and use it in GitHub Desktop.

Select an option

Save louislung/cee96b28783056fc8d26a078b12e9ae3 to your computer and use it in GitHub Desktop.
Keras implementation of RankNet

This gist implements RankNet using the Keras model-subclassing API (a functional wrapper is built only for visualization)

For details please check this blog post.

keywords: learning to rank | tensorflow | keras | custom training loop | ranknet | lambdaRank

import tensorflow as tf
from tensorflow.keras import layers, activations, losses, Model, Input
from tensorflow.nn import leaky_relu
import numpy as np
from itertools import combinations
from tensorflow.keras.utils import plot_model, Progbar
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
# model architecture
class RankNet(Model):
    """Pairwise RankNet model (Burges et al., 2005).

    A single scoring tower (two leaky-ReLU Dense layers + a linear output)
    is applied to both documents of a pair; the model outputs
    sigmoid(o_i - o_j), i.e. the predicted probability that document i
    should rank above document j.
    """

    def __init__(self, num_features=10):
        """Build the shared scoring tower.

        Args:
            num_features: dimensionality of each document's feature vector
                (only used by ``build_graph``; ``call`` itself is shape-agnostic).
        """
        super().__init__()
        self.num_features = num_features
        # Shared weights: the same Dense stack scores both xi and xj.
        self.dense = [
            layers.Dense(16, activation=leaky_relu),
            layers.Dense(8, activation=leaky_relu),
        ]
        self.o = layers.Dense(1, activation='linear')
        self.oi_minus_oj = layers.Subtract()

    def call(self, inputs):
        """Compute P(i ranks above j) for a pair of feature batches.

        Args:
            inputs: pair ``(xi, xj)`` of tensors with shape (batch, num_features).

        Returns:
            Tensor of shape (batch, 1) with sigmoid(o_i - o_j).
        """
        xi, xj = inputs
        densei, densej = xi, xj
        for dense in self.dense:
            densei = dense(densei)
            densej = dense(densej)
        oi = self.o(densei)
        oj = self.o(densej)
        oij = self.oi_minus_oj([oi, oj])
        # Apply the activation function directly instead of instantiating a
        # new layers.Activation object on every forward pass (the original
        # created an untracked layer inside call()).
        return activations.sigmoid(oij)

    def build_graph(self):
        """Return a functional Model wrapper, e.g. for plot_model/summary."""
        # shape must be a tuple: (num_features,) — a bare int is invalid.
        x = [Input(shape=(self.num_features,)), Input(shape=(self.num_features,))]
        return Model(inputs=x, outputs=self.call(x))
# visualize model architecture: build a functional wrapper of an untrained
# RankNet and render the layer graph (shapes omitted)
untrained = RankNet()
plot_model(untrained.build_graph(), show_shapes=False)
# generate synthetic data: nb_query queries, each with a random (>= 3)
# number of documents, 10 random features per document, and integer
# relevance scores in [0, 5)
nb_query = 20
query = np.array([q + 1
                  for q in range(nb_query)
                  for _ in range(int(np.ceil(np.abs(np.random.normal(0, scale=15)) + 2)))])
doc_features = np.random.random((len(query), 10))
doc_scores = np.random.randint(5, size=len(query)).astype(np.float32)

# build all within-query document pairs with the RankNet target
# P(i > j): 1 if doc i scores higher, 0 if lower, 0.5 on ties
xi, xj, pij, pair_id, pair_query_id = [], [], [], [], []
for q in np.unique(query):
    query_idx = np.where(query == q)[0]
    for i, j in combinations(query_idx, 2):
        pair_query_id.append(q)
        pair_id.append((i, j))
        xi.append(doc_features[i])
        xj.append(doc_features[j])
        if doc_scores[i] == doc_scores[j]:
            target = 0.5
        elif doc_scores[i] > doc_scores[j]:
            target = 1
        else:
            target = 0
        pij.append(target)

xi = np.array(xi)
xj = np.array(xj)
pij = np.array(pij)
pair_query_id = np.array(pair_query_id)
# hold out 20% of the pairs for validation; stratify on the query id so
# every query contributes pairs to both the train and the test split
xi_train, xi_test, xj_train, xj_test, pij_train, pij_test, pair_id_train, pair_id_test = train_test_split(
xi, xj, pij, pair_id, test_size=0.2, stratify=pair_query_id)
# train model using compile and fit: the sigmoid output P(i > j) is fit
# against the 0 / 0.5 / 1 pair targets with binary cross-entropy
# NOTE(review): batch_size=1 trains one pair at a time — very slow; a
# larger batch would likely work as well, but is left as the author chose.
ranknet = RankNet()
ranknet.compile(optimizer='adam', loss='binary_crossentropy')
history = ranknet.fit([xi_train, xj_train], pij_train, epochs=50, batch_size=1, validation_data=([xi_test, xj_test], pij_test))
# function for plotting loss
def plot_metrics(train_metric, val_metric=None, metric_name=None, title=None, ylim=5):
    """Plot a training curve (and optionally its validation curve) on the
    current matplotlib axes, with the y-axis clipped to [0, ylim]."""
    plt.title(title)
    plt.ylim(0, ylim)
    # assemble the series to draw, then plot them uniformly
    curves = [(train_metric, 'blue', metric_name)]
    if val_metric is not None:
        curves.append((val_metric, 'green', 'val_' + metric_name))
    for values, color, label in curves:
        plt.plot(values, color=color, label=label)
    plt.legend(loc="upper right")
# plot training vs. validation loss over the 50 epochs
plot_metrics(train_metric=history.history['loss'],
             val_metric=history.history['val_loss'],
             metric_name="Loss",
             title="Loss",
             ylim=1.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment