diff --git a/aiserver.py b/aiserver.py
index 821ecdd9d..db8304ef0 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -12,6 +12,8 @@
 import shutil
 import eventlet

+from modeling.inference_model import GenerationMode
+
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect, contextlib, pickle
 os.system("")
@@ -1730,7 +1732,9 @@ def load_model(model_backend, initial_load=False):

     with use_custom_unpickler(RestrictedUnpickler):
         model = model_backends[model_backend]
+        koboldai_vars.supported_gen_modes = [x.value for x in model.get_supported_gen_modes()]
         model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
+
     koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
     if koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
         koboldai_vars.model = os.path.basename(os.path.normpath(model.path))
@@ -3209,7 +3213,16 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False

-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
+def actionsubmit(
+    data,
+    actionmode=0,
+    force_submit=False,
+    force_prompt_gen=False,
+    disable_recentrng=False,
+    no_generate=False,
+    ignore_aibusy=False,
+    gen_mode=GenerationMode.STANDARD
+):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
         return
@@ -3301,7 +3314,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
             koboldai_vars.prompt = data
             # Clear the startup text from game screen
             emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True, room="UI_1")
-            calcsubmit("") # Run the first action through the generator
+            calcsubmit("", gen_mode=gen_mode) # Run the first action through the generator
             if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                 data = ""
                 force_submit = True
@@ -3367,7 +3380,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,

         if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
             # Off to the tokenizer!
-            calcsubmit("")
+            calcsubmit("", gen_mode=gen_mode)
             if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                 data = ""
                 force_submit = True
@@ -3722,7 +3735,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 #==================================================================#
 # Take submitted text and build the text to be given to generator
 #==================================================================#
-def calcsubmit(txt):
+def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
     anotetxt = ""      # Placeholder for Author's Note text
     forceanote = False # In case we don't have enough actions to hit A.N. depth
     anoteadded = False # In case our budget runs out before we hit A.N. depth
@@ -3764,7 +3777,7 @@ def calcsubmit(txt):
             logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
             start_time = time.time()

-            generate(subtxt, min, max, found_entries)
+            generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
             logger.debug("Submit: generate time {}s".format(time.time()-start_time))

             attention_bias.attention_bias = None
@@ -3832,7 +3845,7 @@ class HordeException(Exception):
 # Send text to generator and deal with output
 #==================================================================#

-def generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
     # Open up token stream
     emit("stream_tokens", True, broadcast=True, room="UI_2")
@@ -3861,7 +3874,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
@@ -6125,23 +6138,31 @@ def UI_2_delete_option(data):
 @socketio.on('submit')
 @logger.catch
 def UI_2_submit(data):
-    if not koboldai_vars.noai and data['theme'] != "":
+    if not koboldai_vars.noai and data['theme']:
+        # Random prompt generation
         logger.debug("doing random prompt")
         memory = koboldai_vars.memory
         koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
         koboldai_vars.lua_koboldbridge.feedback = None
         actionsubmit("", force_submit=True, force_prompt_gen=True)
         koboldai_vars.memory = memory
-    else:
-        logger.debug("doing normal input")
-        koboldai_vars.actions.clear_unused_options()
-        koboldai_vars.lua_koboldbridge.feedback = None
-        koboldai_vars.recentrng = koboldai_vars.recentrngm = None
-        if koboldai_vars.actions.action_count == -1:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-        else:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-
+        return
+
+    logger.debug("doing normal input")
+    koboldai_vars.actions.clear_unused_options()
+    koboldai_vars.lua_koboldbridge.feedback = None
+    koboldai_vars.recentrng = koboldai_vars.recentrngm = None
+
+    gen_mode_name = data.get("gen_mode", None) or "standard"
+    try:
+        gen_mode = GenerationMode(gen_mode_name)
+    except ValueError:
+        # Invalid enum lookup!
+        gen_mode = GenerationMode.STANDARD
+        logger.warning(f"Unknown gen_mode '{gen_mode_name}', using STANDARD! Report this!")
Report this!") + + actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode) + #==================================================================# # Event triggered when user clicks the submit button #==================================================================# diff --git a/koboldai_settings.py b/koboldai_settings.py index bf824a7c5..db101b34c 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -685,6 +685,7 @@ def __init__(self, socketio, koboldai_vars): self._koboldai_vars = koboldai_vars self.alt_multi_gen = False self.bit_8_available = None + self.supported_gen_modes = [] def reset_for_model_load(self): self.simple_randomness = 0 #Set first as this affects other outputs diff --git a/modeling/inference_model.py b/modeling/inference_model.py index a2d4fa63b..8b7f0e3ec 100644 --- a/modeling/inference_model.py +++ b/modeling/inference_model.py @@ -3,6 +3,8 @@ from dataclasses import dataclass import time from typing import List, Optional, Union + +from enum import Enum from logger import logger import torch @@ -12,6 +14,7 @@ GPT2Tokenizer, AutoTokenizer, ) +from modeling.stoppers import Stoppers from modeling.tokenizer import GenericTokenizer from modeling import logits_processors @@ -144,7 +147,10 @@ def __init__(self, **overrides) -> None: class ModelCapabilities: embedding_manipulation: bool = False post_token_hooks: bool = False + + # Used to gauge if manual stopping is possible stopper_hooks: bool = False + # TODO: Support non-live probabilities from APIs post_token_probs: bool = False @@ -154,6 +160,12 @@ class ModelCapabilities: # Some models need to warm up the TPU before use uses_tpu: bool = False +class GenerationMode(Enum): + STANDARD = "standard" + FOREVER = "forever" + UNTIL_EOS = "until_eos" + UNTIL_NEWLINE = "until_newline" + UNTIL_SENTENCE_END = "until_sentence_end" class InferenceModel: """Root class for all models.""" @@ -256,6 +268,7 @@ def core_generate( self, text: list, found_entries: set, + gen_mode: GenerationMode = GenerationMode.STANDARD, ): """Generate story text. Heavily tied to story-specific parameters; if you are making a new generation-based feature, consider `generate_raw()`. @@ -263,6 +276,7 @@ def core_generate( Args: text (list): Encoded input tokens found_entries (set): Entries found for Dynamic WI + gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD Raises: RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check @@ -358,6 +372,7 @@ def core_generate( seed=utils.koboldai_vars.seed if utils.koboldai_vars.full_determinism else None, + gen_mode=gen_mode ) logger.debug( "core_generate: run raw_generate pass {} {}s".format( @@ -532,6 +547,7 @@ def raw_generate( found_entries: set = (), tpu_dynamic_inference: bool = False, seed: Optional[int] = None, + gen_mode: GenerationMode = GenerationMode.STANDARD, **kwargs, ) -> GenerationResult: """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story. @@ -547,6 +563,7 @@ def raw_generate( is_core (bool, optional): Whether this generation is a core story generation. Defaults to False. single_line (bool, optional): Generate one line only.. Defaults to False. found_entries (set, optional): Entries found for Dynamic WI. Defaults to (). + gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD. 

         Raises:
             ValueError: If prompt type is weird
@@ -568,6 +585,29 @@ def raw_generate(
             "wi_scanner_excluded_keys", set()
         )

+        self.gen_state["allow_eos"] = False
+
+        temp_stoppers = []
+
+        if gen_mode not in self.get_supported_gen_modes():
+            logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
+            gen_mode = GenerationMode.STANDARD
+
+        if gen_mode == GenerationMode.FOREVER:
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
+        elif gen_mode == GenerationMode.UNTIL_EOS:
+            self.gen_state["allow_eos"] = True
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
+        elif gen_mode == GenerationMode.UNTIL_NEWLINE:
+            # TODO: Look into replacing `single_line` with `generation_mode`
+            temp_stoppers.append(Stoppers.newline_stopper)
+        elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
+            temp_stoppers.append(Stoppers.sentence_end_stopper)
+
+        self.stopper_hooks += temp_stoppers
+
         utils.koboldai_vars.inference_config.do_core = is_core
         gen_settings = GenerationSettings(*(generation_settings or {}))

@@ -604,6 +644,9 @@ def raw_generate(
             f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
         )

+        for stopper in temp_stoppers:
+            self.stopper_hooks.remove(stopper)
+
         return result

     def generate(
@@ -620,3 +663,19 @@ def generate(
     def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
         for hook in self.post_token_hooks:
             hook(self, input_ids)
+
+    def get_supported_gen_modes(self) -> List[GenerationMode]:
+        """Returns a list of compatible `GenerationMode`s for the current model.
+
+        Returns:
+            List[GenerationMode]: A list of compatible `GenerationMode`s.
+        """
+        ret = [GenerationMode.STANDARD]
+
+        if self.capabilties.stopper_hooks:
+            ret += [
+                GenerationMode.FOREVER,
+                GenerationMode.UNTIL_NEWLINE,
+                GenerationMode.UNTIL_SENTENCE_END,
+            ]
+        return ret
\ No newline at end of file
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 548c1c4af..5a6b18c1a 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -34,6 +34,7 @@
 from modeling.post_token_hooks import PostTokenHooks
 from modeling.inference_models.hf import HFInferenceModel
 from modeling.inference_model import (
+    GenerationMode,
     GenerationResult,
     GenerationSettings,
     ModelCapabilities,
@@ -253,7 +254,10 @@ def new_sample(self, *args, **kwargs):
             assert kwargs.pop("logits_warper", None) is not None
             kwargs["logits_warper"] = KoboldLogitsWarperList()

-            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+            if (
+                utils.koboldai_vars.newlinemode in ["s", "ns"]
+                and not m_self.gen_state["allow_eos"]
+            ):
                 kwargs["eos_token_id"] = -1
             kwargs.setdefault("pad_token_id", 2)
             return new_sample.old_sample(self, *args, **kwargs)
@@ -604,3 +608,9 @@ def breakmodel_device_config(self, config):
             self.breakmodel = False
             self.usegpu = False
             return
+
+    def get_supported_gen_modes(self) -> List[GenerationMode]:
+        # UNTIL_EOS is additionally supported here because the torch sample patch above stops banning the EOS token when allow_eos is set.
+        return super().get_supported_gen_modes() + [
+            GenerationMode.UNTIL_EOS
+        ]
\ No newline at end of file
diff --git a/modeling/stoppers.py b/modeling/stoppers.py
index cdf5b6e2e..e4c20be63 100644
--- a/modeling/stoppers.py
+++ b/modeling/stoppers.py
@@ -3,15 +3,12 @@
 import torch

 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
-
+from modeling import inference_model

 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ def core_stopper(

     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ def dynamic_wi_scanner(

     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ def chat_mode_stopper(

     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:

@@ -154,14 +151,22 @@ def stop_sequence_stopper(

     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""
+        # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
+        return Stoppers.newline_stopper(model, input_ids)

+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -176,3 +181,30 @@ def singleline_stopper(
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+ if "sentence_end_in_sequence" not in model.gen_state: + model.gen_state["sentence_end_sequence"] = [False] * len(input_ids) + + for sequence_idx, batch_sequence in enumerate(input_ids): + decoded = model.tokenizer.decode(batch_sequence[-1]) + for end in SENTENCE_ENDS: + if end in decoded: + model.gen_state["sentence_end_sequence"][sequence_idx] = True + break + + if all(model.gen_state["sentence_end_sequence"]): + del model.gen_state["sentence_end_sequence"] + return True + return False \ No newline at end of file diff --git a/static/koboldai.css b/static/koboldai.css index 1c2ebef3b..a08cae4d2 100644 --- a/static/koboldai.css +++ b/static/koboldai.css @@ -2730,13 +2730,14 @@ body { #context-menu > hr { /* Division Color*/ border-top: 2px solid var(--context_menu_division); - margin: 5px 5px; + margin: 3px 5px; } .context-menu-item { - padding: 5px; + padding: 4px; padding-right: 25px; min-width: 100px; + white-space: nowrap; } .context-menu-item:hover { @@ -2747,11 +2748,16 @@ body { .context-menu-item > .material-icons-outlined { position: relative; - top: 2px; + top: 3px; font-size: 15px; margin-right: 5px; } +.context-menu-item > .context-menu-label { + position: relative; + top: 1px; +} + /* Substitutions */ #Substitutions { margin-left: 10px; diff --git a/static/koboldai.js b/static/koboldai.js index 87f2f944b..d2ecac0a4 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -85,6 +85,7 @@ let story_id = -1; var dirty_chunks = []; var initial_socketio_connection_occured = false; var selected_model_data; +var supported_gen_modes = []; var privacy_mode_enabled = false; var ai_busy = false; var can_show_options = false; @@ -162,7 +163,36 @@ const context_menu_actions = { "wi-img-upload-button": [ {label: "Upload Image", icon: "file_upload", enabledOn: "ALWAYS", click: wiImageReplace}, {label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage}, - ] + ], + "submit-button": [ + {label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()}, + null, + { + label: "Generate Forever", + icon: "edit_off", + enabledOn: () => supported_gen_modes.includes("forever"), + click: () => storySubmit("forever") + }, + { + label: "Generate Until EOS", + icon: "edit_off", + enabledOn: () => supported_gen_modes.includes("until_eos"), + click: () => storySubmit("until_eos") + }, + null, + { + label: "Finish Line", + icon: "edit_off", + enabledOn: () => supported_gen_modes.includes("until_newline"), + click: () => storySubmit("until_newline") + }, + { + label: "Finish Sentence", + icon: "edit_off", + enabledOn: () => supported_gen_modes.includes("until_sentence_end"), + click: () => storySubmit("until_sentence_end") + }, + ], }; let context_menu_cache = []; @@ -254,10 +284,17 @@ function disconnect() { document.getElementById("disconnect_message").classList.remove("hidden"); } -function storySubmit() { +function storySubmit(genMode=null) { + const textInput = document.getElementById("input_text"); + const themeInput = document.getElementById("themetext"); disruptStoryState(); - socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value}); - document.getElementById('input_text').value = ''; + socket.emit('submit', { + data: textInput.value, + theme: themeInput.value, + gen_mode: genMode, + }); + + textInput.value = ''; document.getElementById('themetext').value = ''; } @@ -1009,6 +1046,9 @@ function var_changed(data) { //special case for welcome text since we want to 
   } else if (data.classname == 'model' && data.name == 'welcome') {
     document.getElementById('welcome_text').innerHTML = data.value;
+  //Special case for permitted generation modes
+  } else if (data.classname == 'model' && data.name == 'supported_gen_modes') {
+    supported_gen_modes = data.value;
   //Basic Data Syncing
   } else {
     var elements_to_change = document.getElementsByClassName("var_sync_"+data.classname.replace(" ", "_")+"_"+data.name.replace(" ", "_"));
@@ -5976,8 +6016,21 @@ function position_context_menu(contextMenu, x, y) {
     right: x + width,
   };

+  // Slide over if running against the window bounds.
   if (farMenuBounds.right > bounds.right) x -= farMenuBounds.right - bounds.right;
-  if (farMenuBounds.bottom > bounds.bottom) y -= farMenuBounds.bottom - bounds.bottom;
+
+  if (farMenuBounds.bottom > bounds.bottom) {
+    // We've hit the bottom.
+
+    // The old algorithm pushed the menu against the wall, similar to what's
+    // done on the x-axis:
+    // y -= farMenuBounds.bottom - bounds.bottom;
+    // But now we flip the menu so that it opens upward from the cursor instead:
+    y -= (height + 5);
+    // The main advantage of this approach is that the cursor never ends up
+    // directly on top of a context menu item right after the menu is opened
+    // (hence the extra 5px offset).
+  }

   contextMenu.style.left = `${x}px`;
   contextMenu.style.top = `${y}px`;
@@ -6252,21 +6305,23 @@ process_cookies();
         continue;
       }

+      const enableCriteriaIsFunction = typeof action.enabledOn === "function";
-      let item = $e("div", contextMenu, {
+      const itemEl = $e("div", contextMenu, {
         classes: ["context-menu-item", "noselect", `context-menu-${key}`],
-        "enabled-on": action.enabledOn,
+        "enabled-on": enableCriteriaIsFunction ? "CALLBACK" : action.enabledOn,
         "cache-index": context_menu_cache.length
       });
+      itemEl.enabledOnCallback = action.enabledOn;
       context_menu_cache.push({shouldShow: action.shouldShow});

-      let icon = $e("span", item, {classes: ["material-icons-outlined"], innerText: action.icon});
-      item.append(action.label);
+      const icon = $e("span", itemEl, {classes: ["material-icons-outlined"], innerText: action.icon});
+      $e("span", itemEl, {classes: ["context-menu-label"], innerText: action.label});

-      item.addEventListener("mousedown", e => e.preventDefault());
+      itemEl.addEventListener("mousedown", e => e.preventDefault());

       // Expose the "summonEvent" to enable access to original context menu target.
-      item.addEventListener("click", () => action.click(summonEvent));
+      itemEl.addEventListener("click", () => action.click(summonEvent));
     }
   }
@@ -6289,6 +6344,10 @@ process_cookies();

     // Show only applicable actions in the context menu
     let contextMenuType = target.getAttribute("context-menu");
+
+    // If context menu is not present, return
+    if (!context_menu_actions[contextMenuType]) return;
+
     for (const contextMenuItem of contextMenu.childNodes) {
       let shouldShow = contextMenuItem.classList.contains(`context-menu-${contextMenuType}`);
@@ -6316,10 +6375,10 @@ process_cookies();
     // Disable non-applicable items
     $(".context-menu-item").addClass("disabled");
-
+
     // A selection is made
     if (getSelectionText()) $(".context-menu-item[enabled-on=SELECTION]").removeClass("disabled");
-
+
     // The caret is placed
     if (get_caret_position(target) !== null) $(".context-menu-item[enabled-on=CARET]").removeClass("disabled");
@@ -6328,6 +6387,11 @@ process_cookies();

     $(".context-menu-item[enabled-on=ALWAYS]").removeClass("disabled");

+    for (const contextMenuItem of document.querySelectorAll(".context-menu-item[enabled-on=CALLBACK]")) {
+      if (!contextMenuItem.enabledOnCallback()) continue;
+      contextMenuItem.classList.remove("disabled");
+    }
+
     // Make sure hr isn't first or last visible element
     let visibles = [];
     for (const item of contextMenu.children) {
diff --git a/templates/index_new.html b/templates/index_new.html
index 50fb02816..64c4c76b2 100644
--- a/templates/index_new.html
+++ b/templates/index_new.html
@@ -110,9 +110,9 @@
-
+
-
+
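
Illustrative usage sketch (not part of the patch): the helper below shows how a caller could resolve a client-supplied mode string against the loaded backend, mirroring the fallback behaviour of UI_2_submit and raw_generate above. The helper name and the explicit `model` argument are assumptions for illustration; in aiserver.py the loaded InferenceModel is a module-level `model` object.

    from modeling.inference_model import GenerationMode, InferenceModel

    def resolve_gen_mode(requested: str, model: InferenceModel) -> GenerationMode:
        # Map the client string onto the enum, falling back to STANDARD on
        # unknown values, just like UI_2_submit does.
        try:
            mode = GenerationMode(requested or "standard")
        except ValueError:
            return GenerationMode.STANDARD
        # Downgrade modes the loaded backend cannot honour, as raw_generate does.
        if mode not in model.get_supported_gen_modes():
            return GenerationMode.STANDARD
        return mode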