Skip to content

Commit

Permalink
Merge branch 'staging' into laserprec/code-cov-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
anargyri authored Mar 9, 2022
2 parents 0fa2ccf + 60426b4 commit c99afa7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
6 changes: 3 additions & 3 deletions recommenders/models/sar/sar_singlenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ def fit(self, df):
self.item_frequencies = item_cooccurrence.diagonal()

logger.info("Calculating item similarity")
if self.similarity_type is COOCCUR:
if self.similarity_type == COOCCUR:
logger.info("Using co-occurrence based similarity")
self.item_similarity = item_cooccurrence
elif self.similarity_type is JACCARD:
elif self.similarity_type == JACCARD:
logger.info("Using jaccard based similarity")
self.item_similarity = jaccard(item_cooccurrence).astype(
df[self.col_rating].dtype
)
elif self.similarity_type is LIFT:
elif self.similarity_type == LIFT:
logger.info("Using lift based similarity")
self.item_similarity = lift(item_cooccurrence).astype(
df[self.col_rating].dtype
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/recommenders/models/test_sar_singlenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import codecs
import csv
import itertools
import json
import pytest
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -389,3 +390,26 @@ def test_get_normalized_scores(header):
assert actual.shape == (2, 7)
assert isinstance(actual, np.ndarray)
assert np.isclose(expected, np.asarray(actual)).all()


def test_match_similarity_type_from_json_file(header):
# store parameters in json
params_str = json.dumps({'similarity_type': 'lift'})
# load parameters in json
params = json.loads(params_str)

params.update(header)

model = SARSingleNode(**params)

train = pd.DataFrame(
{
header["col_user"]: [1, 1, 1, 1, 2, 2, 2, 2],
header["col_item"]: [1, 2, 3, 4, 1, 5, 6, 7],
header["col_rating"]: [3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0, 5.0],
header["col_timestamp"]: [1, 20, 30, 400, 50, 60, 70, 800],
}
)

# make sure fit still works when similarity type is loaded from a json file
model.fit(train)

0 comments on commit c99afa7

Please sign in to comment.