
Commit

Merge pull request #414 from one-some/submit-ctx-menu
Submit context menu
henk717 authored Jul 29, 2023
2 parents 276efa6 + e2b3fa1 commit 21d2085
Showing 8 changed files with 240 additions and 47 deletions.
57 changes: 39 additions & 18 deletions aiserver.py
@@ -12,6 +2,8 @@
import shutil
import eventlet

from modeling.inference_model import GenerationMode

eventlet.monkey_patch(all=True, thread=False, os=False)
import os, inspect, contextlib, pickle
os.system("")
@@ -1730,7 +1732,9 @@ def load_model(model_backend, initial_load=False):

with use_custom_unpickler(RestrictedUnpickler):
model = model_backends[model_backend]
koboldai_vars.supported_gen_modes = [x.value for x in model.get_supported_gen_modes()]
model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)

koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
if koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
koboldai_vars.model = os.path.basename(os.path.normpath(model.path))
@@ -3209,7 +3213,16 @@ def check_for_backend_compilation():
break
koboldai_vars.checking = False

def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
def actionsubmit(
data,
actionmode=0,
force_submit=False,
force_prompt_gen=False,
disable_recentrng=False,
no_generate=False,
ignore_aibusy=False,
gen_mode=GenerationMode.STANDARD
):
# Ignore new submissions if the AI is currently busy
if(koboldai_vars.aibusy):
return
@@ -3301,7 +3314,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
koboldai_vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True, room="UI_1")
calcsubmit("") # Run the first action through the generator
calcsubmit("", gen_mode=gen_mode) # Run the first action through the generator
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
data = ""
force_submit = True
@@ -3367,7 +3380,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,

if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
# Off to the tokenizer!
calcsubmit("")
calcsubmit("", gen_mode=gen_mode)
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
data = ""
force_submit = True
@@ -3722,7 +3735,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
anotetxt = "" # Placeholder for Author's Note text
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
@@ -3764,7 +3777,7 @@ def calcsubmit(txt):
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))

start_time = time.time()
generate(subtxt, min, max, found_entries)
generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
logger.debug("Submit: generate time {}s".format(time.time()-start_time))
attention_bias.attention_bias = None

@@ -3832,7 +3845,7 @@ class HordeException(Exception):
# Send text to generator and deal with output
#==================================================================#

def generate(txt, minimum, maximum, found_entries=None):
def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
# Open up token stream
emit("stream_tokens", True, broadcast=True, room="UI_2")

@@ -3861,7 +3874,7 @@ def generate(txt, minimum, maximum, found_entries=None):
# Submit input text to generator
try:
start_time = time.time()
genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
@@ -6125,23 +6138,31 @@ def UI_2_delete_option(data):
@socketio.on('submit')
@logger.catch
def UI_2_submit(data):
if not koboldai_vars.noai and data['theme'] != "":
if not koboldai_vars.noai and data['theme']:
# Random prompt generation
logger.debug("doing random prompt")
memory = koboldai_vars.memory
koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
koboldai_vars.lua_koboldbridge.feedback = None
actionsubmit("", force_submit=True, force_prompt_gen=True)
koboldai_vars.memory = memory
else:
logger.debug("doing normal input")
koboldai_vars.actions.clear_unused_options()
koboldai_vars.lua_koboldbridge.feedback = None
koboldai_vars.recentrng = koboldai_vars.recentrngm = None
if koboldai_vars.actions.action_count == -1:
actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
else:
actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)

return

logger.debug("doing normal input")
koboldai_vars.actions.clear_unused_options()
koboldai_vars.lua_koboldbridge.feedback = None
koboldai_vars.recentrng = koboldai_vars.recentrngm = None

gen_mode_name = data.get("gen_mode", None) or "standard"
try:
gen_mode = GenerationMode(gen_mode_name)
except ValueError:
# Invalid enum lookup!
gen_mode = GenerationMode.STANDARD
logger.warning(f"Unknown gen_mode '{gen_mode_name}', using STANDARD! Report this!")

actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)

#==================================================================#
# Event triggered when user clicks the submit button
#==================================================================#
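The aiserver.py changes above thread a single gen_mode argument from the "submit" socket handler through actionsubmit(), calcsubmit(), and generate() down to model.core_generate(). A minimal sketch (not part of this commit) of the payload the UI is expected to send, using the key names read by UI_2_submit(); the emit call itself is illustrative only:

# Sketch of a UI_2 "submit" payload once the context menu picks a generation mode.
payload = {
    "data": "The knight drew her sword.",
    "theme": "",                       # a non-empty theme triggers random prompt generation instead
    "gen_mode": "until_sentence_end",  # any GenerationMode value; unknown strings fall back to "standard"
}
# socketio.emit("submit", payload)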
1 change: 1 addition & 0 deletions koboldai_settings.py
@@ -685,6 +685,7 @@ def __init__(self, socketio, koboldai_vars):
self._koboldai_vars = koboldai_vars
self.alt_multi_gen = False
self.bit_8_available = None
self.supported_gen_modes = []

def reset_for_model_load(self):
self.simple_randomness = 0 #Set first as this affects other outputs
59 changes: 59 additions & 0 deletions modeling/inference_model.py
@@ -3,6 +3,8 @@
from dataclasses import dataclass
import time
from typing import List, Optional, Union

from enum import Enum
from logger import logger

import torch
@@ -12,6 +14,7 @@
GPT2Tokenizer,
AutoTokenizer,
)
from modeling.stoppers import Stoppers
from modeling.tokenizer import GenericTokenizer
from modeling import logits_processors

@@ -144,7 +147,10 @@ def __init__(self, **overrides) -> None:
class ModelCapabilities:
embedding_manipulation: bool = False
post_token_hooks: bool = False

# Used to gauge if manual stopping is possible
stopper_hooks: bool = False

# TODO: Support non-live probabilities from APIs
post_token_probs: bool = False

@@ -154,6 +160,12 @@ class ModelCapabilities:
# Some models need to warm up the TPU before use
uses_tpu: bool = False

class GenerationMode(Enum):
STANDARD = "standard"
FOREVER = "forever"
UNTIL_EOS = "until_eos"
UNTIL_NEWLINE = "until_newline"
UNTIL_SENTENCE_END = "until_sentence_end"

class InferenceModel:
"""Root class for all models."""
@@ -256,13 +268,15 @@ def core_generate(
self,
text: list,
found_entries: set,
gen_mode: GenerationMode = GenerationMode.STANDARD,
):
"""Generate story text. Heavily tied to story-specific parameters; if
you are making a new generation-based feature, consider `generate_raw()`.
Args:
text (list): Encoded input tokens
found_entries (set): Entries found for Dynamic WI
gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD
Raises:
RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check
@@ -358,6 +372,7 @@ def core_generate(
seed=utils.koboldai_vars.seed
if utils.koboldai_vars.full_determinism
else None,
gen_mode=gen_mode
)
logger.debug(
"core_generate: run raw_generate pass {} {}s".format(
@@ -532,6 +547,7 @@ def raw_generate(
found_entries: set = (),
tpu_dynamic_inference: bool = False,
seed: Optional[int] = None,
gen_mode: GenerationMode = GenerationMode.STANDARD,
**kwargs,
) -> GenerationResult:
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -547,6 +563,7 @@
is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
single_line (bool, optional): Generate one line only.. Defaults to False.
found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.
Raises:
ValueError: If prompt type is weird
@@ -568,6 +585,29 @@
"wi_scanner_excluded_keys", set()
)

self.gen_state["allow_eos"] = False

temp_stoppers = []

if gen_mode not in self.get_supported_gen_modes():
gen_mode = GenerationMode.STANDARD
logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")

if gen_mode == GenerationMode.FOREVER:
self.gen_state["stop_at_genamt"] = False
max_new = 1e7
elif gen_mode == GenerationMode.UNTIL_EOS:
self.gen_state["allow_eos"] = True
self.gen_state["stop_at_genamt"] = False
max_new = 1e7
elif gen_mode == GenerationMode.UNTIL_NEWLINE:
# TODO: Look into replacing `single_line` with `generation_mode`
temp_stoppers.append(Stoppers.newline_stopper)
elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
temp_stoppers.append(Stoppers.sentence_end_stopper)

self.stopper_hooks += temp_stoppers

utils.koboldai_vars.inference_config.do_core = is_core
gen_settings = GenerationSettings(*(generation_settings or {}))
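
UNTIL_NEWLINE and UNTIL_SENTENCE_END work by temporarily registering stoppers from modeling/stoppers.py and removing them after generation finishes. The stoppers themselves are not shown in this diff; the sketch below is only an assumed shape for such a hook — a callable that sees the tokens generated so far and returns True to halt (the name and body are hypothetical):

# Hypothetical stopper sketch; the real implementations live in modeling/stoppers.py.
import torch

def newline_stopper_sketch(model, input_ids: torch.LongTensor) -> bool:
    # Decode only the newest token and stop once it contains a newline.
    newest_token = int(input_ids[0, -1])
    return "\n" in model.tokenizer.decode([newest_token])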

@@ -604,6 +644,9 @@
f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
)

for stopper in temp_stoppers:
self.stopper_hooks.remove(stopper)

return result

def generate(
@@ -620,3 +663,19 @@ def generate(
def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
for hook in self.post_token_hooks:
hook(self, input_ids)

def get_supported_gen_modes(self) -> List[GenerationMode]:
"""Returns a list of compatible `GenerationMode`s for the current model.
Returns:
List[GenerationMode]: A list of compatible `GenerationMode`s.
"""
ret = [GenerationMode.STANDARD]

if self.capabilties.stopper_hooks:
ret += [
GenerationMode.FOREVER,
GenerationMode.UNTIL_NEWLINE,
GenerationMode.UNTIL_SENTENCE_END,
]
return ret
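
get_supported_gen_modes() derives the advertised modes from the backend's capabilities, and load_model() in aiserver.py mirrors the result into koboldai_vars.supported_gen_modes as plain strings for the UI. A short usage sketch, assuming `model` is a loaded InferenceModel subclass instance:

modes = model.get_supported_gen_modes()
# A backend without stopper hooks reports only [GenerationMode.STANDARD];
# one with stopper hooks also gets FOREVER, UNTIL_NEWLINE, and UNTIL_SENTENCE_END.
koboldai_vars.supported_gen_modes = [x.value for x in modes]  # what load_model() does above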
12 changes: 11 additions & 1 deletion modeling/inference_models/hf_torch.py
@@ -34,6 +34,7 @@
from modeling.post_token_hooks import PostTokenHooks
from modeling.inference_models.hf import HFInferenceModel
from modeling.inference_model import (
GenerationMode,
GenerationResult,
GenerationSettings,
ModelCapabilities,
@@ -253,7 +254,10 @@ def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = KoboldLogitsWarperList()

if utils.koboldai_vars.newlinemode in ["s", "ns"]:
if (
utils.koboldai_vars.newlinemode in ["s", "ns"]
and not m_self.gen_state["allow_eos"]
):
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
@@ -604,3 +608,9 @@ def breakmodel_device_config(self, config):
self.breakmodel = False
self.usegpu = False
return

def get_supported_gen_modes(self) -> List[GenerationMode]:
# This changes a torch patch to disallow eos as a bad word.
return super().get_supported_gen_modes() + [
GenerationMode.UNTIL_EOS
]
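
Because the patched new_sample() above only suppresses the EOS token while gen_state["allow_eos"] is false, the HF torch backend can additionally advertise UNTIL_EOS. A sketch of the combined list this override would return, assuming (not shown in this diff) that the backend's ModelCapabilities sets stopper_hooks=True:

expected = [
    GenerationMode.STANDARD,
    GenerationMode.FOREVER,
    GenerationMode.UNTIL_NEWLINE,
    GenerationMode.UNTIL_SENTENCE_END,
    GenerationMode.UNTIL_EOS,
]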
(Diffs for the remaining 4 changed files were not loaded in this view.)
