Skip to content

Commit

Permalink
updated Model.py (blobcity#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
sreyan-ghosh authored Sep 29, 2021
1 parent a598f23 commit b10a978
Showing 1 changed file with 55 additions and 5 deletions.
60 changes: 55 additions & 5 deletions blobcity/store/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def parameters(self):
"""
return: Dictionary
funciton return dictionary consisting of tuned parameters value for the trained model.
function return dictionary consisting of tuned parameters value for the trained model.
"""
return self.params

Expand All @@ -51,7 +51,57 @@ def features(self):
return self.featureList

def save(self, path_pref='./'):
final_path = os.path.join(path_pref, 'my_model.pkl')
pickle.dumps(final_path)
# print(final_path)
return final_path
"""
param: Path Prefix or Entire Path. Supported formats are .pkl and .h5. Default is .pkl
returns: Final filepath of stored serialized file
function saves the model and its weights serially and returns the filepath where it is saved.
"""
path_components = path_pref.split('.')
if len(path_components)<=2:
extension = path_components[1]
else:
extension = path_components[2]

if extension == '/':
final_path = os.path.join(path_pref, 'my_model.pkl')
pickle.dump(self.model, open(final_path, 'wb'))
print("The model is stored at {}".format(final_path))
return final_path

elif extension == 'pkl':
final_path = path_pref
pickle.dump(self.model, open(final_path, 'wb'))
print("The model is stored at {}".format(final_path))
return final_path

elif extension == 'h5':
final_path = path_pref
try:
self.model.save(final_path)
print("The model is stored at {}".format(final_path))
return final_path
except:
raise TypeError("Your model is not a Keras model of type .h5. Try .pkl extension.")

else:
raise TypeError(f"{extension} file type must be .pkl or .h5")

def load(self, filepath):
"""
param: (required) the filepath to the stored model. Supports .h5 or .pkl models.
returns: Model file
function loads the serialized model from .pkl or .h5 format to usable format.
"""
path_components = path_pref.split('.')
if len(path_components)<=2:
extension = path_components[1]
else:
extension = path_components[2]

if extension == 'pkl':
self.model = pickle.load(open(filepath, 'rb'))
elif extension == 'h5':
self.model = tf.keras.models.load_model(filepath)
return self.model

0 comments on commit b10a978

Please sign in to comment.