From b10a978b00c1bf0917f39e9934d0e4025309cb22 Mon Sep 17 00:00:00 2001 From: Sreyan Ghosh <60854658+sreyan-ghosh@users.noreply.github.com> Date: Wed, 29 Sep 2021 13:26:10 +0000 Subject: [PATCH] updated Model.py (#20) --- blobcity/store/Model.py | 60 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/blobcity/store/Model.py b/blobcity/store/Model.py index 9075b34..7e00f19 100644 --- a/blobcity/store/Model.py +++ b/blobcity/store/Model.py @@ -38,7 +38,7 @@ def parameters(self): """ return: Dictionary - funciton return dictionary consisting of tuned parameters value for the trained model. + function return dictionary consisting of tuned parameters value for the trained model. """ return self.params @@ -51,7 +51,57 @@ def features(self): return self.featureList def save(self, path_pref='./'): - final_path = os.path.join(path_pref, 'my_model.pkl') - pickle.dumps(final_path) - # print(final_path) - return final_path \ No newline at end of file + """ + param: Path Prefix or Entire Path. Supported formats are .pkl and .h5. Default is .pkl + returns: Final filepath of stored serialized file + + function saves the model and its weights serially and returns the filepath where it is saved. + """ + path_components = path_pref.split('.') + if len(path_components)<=2: + extension = path_components[1] + else: + extension = path_components[2] + + if extension == '/': + final_path = os.path.join(path_pref, 'my_model.pkl') + pickle.dump(self.model, open(final_path, 'wb')) + print("The model is stored at {}".format(final_path)) + return final_path + + elif extension == 'pkl': + final_path = path_pref + pickle.dump(self.model, open(final_path, 'wb')) + print("The model is stored at {}".format(final_path)) + return final_path + + elif extension == 'h5': + final_path = path_pref + try: + self.model.save(final_path) + print("The model is stored at {}".format(final_path)) + return final_path + except: + raise TypeError("Your model is not a Keras model of type .h5. Try .pkl extension.") + + else: + raise TypeError(f"{extension} file type must be .pkl or .h5") + + def load(self, filepath): + """ + param: (required) the filepath to the stored model. Supports .h5 or .pkl models. + returns: Model file + + function loads the serialized model from .pkl or .h5 format to usable format. + """ + path_components = path_pref.split('.') + if len(path_components)<=2: + extension = path_components[1] + else: + extension = path_components[2] + + if extension == 'pkl': + self.model = pickle.load(open(filepath, 'rb')) + elif extension == 'h5': + self.model = tf.keras.models.load_model(filepath) + return self.model \ No newline at end of file