diff --git a/comfy/caching.py b/comfy/caching.py new file mode 100644 index 00000000000..ef047dcc5d8 --- /dev/null +++ b/comfy/caching.py @@ -0,0 +1,316 @@ +import itertools +from typing import Sequence, Mapping + +import nodes + +from comfy.graph_utils import is_link + +class CacheKeySet: + def __init__(self, dynprompt, node_ids, is_changed_cache): + self.keys = {} + self.subcache_keys = {} + + def add_keys(node_ids): + raise NotImplementedError() + + def all_node_ids(self): + return set(self.keys.keys()) + + def get_used_keys(self): + return self.keys.values() + + def get_used_subcache_keys(self): + return self.subcache_keys.values() + + def get_data_key(self, node_id): + return self.keys.get(node_id, None) + + def get_subcache_key(self, node_id): + return self.subcache_keys.get(node_id, None) + +class Unhashable: + def __init__(self): + self.value = float("NaN") + +def to_hashable(obj): + # So that we don't infinitely recurse since frozenset and tuples + # are Sequences. + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + elif isinstance(obj, Mapping): + return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) + elif isinstance(obj, Sequence): + return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) + else: + # TODO - Support other objects like tensors? + return Unhashable() + +class CacheKeySetID(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.add_keys(node_ids) + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = (node_id, node["class_type"]) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + +class CacheKeySetInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.is_changed_cache = is_changed_cache + self.add_keys(node_ids) + + def include_node_id_in_input(self): + return False + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def get_node_signature(self, dynprompt, node_id): + signature = [] + ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) + signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + for ancestor_id in ancestors: + signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + return to_hashable(signature) + + def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, self.is_changed_cache.get(node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT): + signature.append(node_id) + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. + def get_ordered_ancestry(self, dynprompt, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + inputs = dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + +class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + + def include_node_id_in_input(self): + return True + +class BasicCache: + def __init__(self, key_class): + self.key_class = key_class + self.dynprompt = None + self.cache_key_set = None + self.cache = {} + self.subcaches = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + self.dynprompt = dynprompt + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.is_changed_cache = is_changed_cache + + def all_node_ids(self): + assert self.cache_key_set is not None + node_ids = self.cache_key_set.all_node_ids() + for subcache in self.subcaches.values(): + node_ids = node_ids.union(subcache.all_node_ids()) + return node_ids + + def clean_unused(self): + assert self.cache_key_set is not None + preserve_keys = set(self.cache_key_set.get_used_keys()) + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + to_remove = [] + for key in self.cache: + if key not in preserve_keys: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + to_remove = [] + for key in self.subcaches: + if key not in preserve_subcaches: + to_remove.append(key) + for key in to_remove: + del self.subcaches[key] + + def _set_immediate(self, node_id, value): + assert self.cache_key_set is not None + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): + assert self.cache_key_set is not None + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + return self.cache[cache_key] + else: + return None + + def _ensure_subcache(self, node_id, children_ids): + assert self.cache_key_set is not None + subcache_key = self.cache_key_set.get_subcache_key(node_id) + subcache = self.subcaches.get(subcache_key, None) + if subcache is None: + subcache = BasicCache(self.key_class) + self.subcaches[subcache_key] = subcache + subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + return subcache + + def _get_subcache(self, node_id): + assert self.cache_key_set is not None + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + return self.subcaches[subcache_key] + else: + return None + + def recursive_debug_dump(self): + result = [] + for key in self.cache: + result.append({"key": key, "value": self.cache[key]}) + for key in self.subcaches: + result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) + return result + +class HierarchicalCache(BasicCache): + def __init__(self, key_class): + super().__init__(key_class) + + def _get_cache_for(self, node_id): + parent_id = self.dynprompt.get_parent_node_id(node_id) + if parent_id is None: + return self + + hierarchy = [] + while parent_id is not None: + hierarchy.append(parent_id) + parent_id = self.dynprompt.get_parent_node_id(parent_id) + + cache = self + for parent_id in reversed(hierarchy): + cache = cache._get_subcache(parent_id) + if cache is None: + return None + return cache + + def get(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return cache._get_immediate(node_id) + + def set(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + cache._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._ensure_subcache(node_id, children_ids) + + def all_active_values(self): + active_nodes = self.all_node_ids() + result = [] + for node_id in active_nodes: + value = self.get(node_id) + if value is not None: + result.append(value) + return result + +class LRUCache(BasicCache): + def __init__(self, key_class, max_size=100): + super().__init__(key_class) + self.max_size = max_size + self.min_generation = 0 + self.generation = 0 + self.used_generation = {} + self.children = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + super().set_prompt(dynprompt, node_ids, is_changed_cache) + self.generation += 1 + for node_id in node_ids: + self._mark_used(node_id) + print("LRUCache: Now at generation %d" % self.generation) + + def clean_unused(self): + print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size)) + while len(self.cache) > self.max_size and self.min_generation < self.generation: + print("LRUCache: Evicting generation %d" % self.min_generation) + self.min_generation += 1 + to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + for key in to_remove: + del self.cache[key] + del self.used_generation[key] + if key in self.children: + del self.children[key] + + def get(self, node_id): + self._mark_used(node_id) + return self._get_immediate(node_id) + + def _mark_used(self, node_id): + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key is not None: + self.used_generation[cache_key] = self.generation + + def set(self, node_id, value): + self._mark_used(node_id) + return self._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + self.cache_key_set.add_keys(children_ids) + self._mark_used(node_id) + cache_key = self.cache_key_set.get_data_key(node_id) + self.children[cache_key] = [] + for child_id in children_ids: + self._mark_used(child_id) + self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) + return self + + def all_active_values(self): + explored = set() + to_explore = set(self.cache_key_set.get_used_keys()) + while len(to_explore) > 0: + cache_key = to_explore.pop() + if cache_key not in explored: + self.used_generation[cache_key] = self.generation + explored.add(cache_key) + if cache_key in self.children: + to_explore.update(self.children[cache_key]) + return [self.cache[key] for key in explored if key in self.cache] + diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b4bbfbfab53..2cbefefebd9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -87,6 +87,10 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) +cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") +cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") diff --git a/comfy/graph.py b/comfy/graph.py new file mode 100644 index 00000000000..2612317f9e6 --- /dev/null +++ b/comfy/graph.py @@ -0,0 +1,172 @@ +import nodes + +from comfy.graph_utils import is_link + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + return None + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, input_name) + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input)) + value = inputs[to_input] + if not is_link(value): + raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input)) + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): + if unique_id in self.pendingNodes: + return + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + if subgraph_nodes is not None and from_node_id not in subgraph_nodes: + continue + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + is_lazy = "lazy" in input_info and input_info["lazy"] + if include_lazy or not is_lazy: + self.add_strong_link(from_node_id, from_socket, unique_id) + + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + +# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, +# it can still be returned to the graph after having further dependencies added. +class ExecutionList(TopologicalSort): + def __init__(self, dynprompt, output_cache): + super().__init__(dynprompt) + self.output_cache = output_cache + self.staged_node_id = None + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if self.output_cache.get(from_node_id) is not None: + # Nothing to do + return + super().add_strong_link(from_node_id, from_socket, to_node_id) + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None + available = self.get_ready_nodes() + if len(available) == 0: + raise Exception("Dependency cycle detected") + next_node = available[0] + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. + for node_id in available: + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + next_node = node_id + break + self.staged_node_id = next_node + return self.staged_node_id + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + self.pop_node(node_id) + self.staged_node_id = None + +# Return this from a node and any users will be blocked with the given error message. +class ExecutionBlocker: + def __init__(self, message): + self.message = message + diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py new file mode 100644 index 00000000000..a0042e078f7 --- /dev/null +++ b/comfy/graph_utils.py @@ -0,0 +1,140 @@ +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + +# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end +class GraphBuilder: + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix = None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() + else: + self.prefix = prefix + self.nodes = {} + self.id_gen = 1 + + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + if graph_index is not None: + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = "%s.%d.%d." % (root, call_index, graph_index) + GraphBuilder._default_prefix_graph_index += 1 + return result + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if is_link(value) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + self.override_display_id = None + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id + + def serialize(self): + serialized = { + "class_type": self.class_type, + "inputs": self.inputs + } + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id + return serialized + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + node_id + new_node = { "class_type": node_info["class_type"], "inputs": {} } + for input_name, input_value in node_info.get("inputs", {}).items(): + if is_link(input_value): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if is_link(output): + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) + diff --git a/custom_nodes/execution-inversion-demo-comfyui/__init__.py b/custom_nodes/execution-inversion-demo-comfyui/__init__.py new file mode 100644 index 00000000000..c872ceba708 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/__init__.py @@ -0,0 +1,25 @@ +from .nodes import GENERAL_NODE_CLASS_MAPPINGS, GENERAL_NODE_DISPLAY_NAME_MAPPINGS +from .components import setup_js, COMPONENT_NODE_CLASS_MAPPINGS, COMPONENT_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .utility_nodes import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(GENERAL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(GENERAL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) + +setup_js() + diff --git a/custom_nodes/execution-inversion-demo-comfyui/components.py b/custom_nodes/execution-inversion-demo-comfyui/components.py new file mode 100644 index 00000000000..fb017db3dce --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/components.py @@ -0,0 +1,216 @@ +import os +import shutil +import folder_paths +import json +import copy +import comfy.graph_utils + +comfy_path = os.path.dirname(folder_paths.__file__) +js_path = os.path.join(comfy_path, "web", "extensions") +inversion_demo_path = os.path.dirname(__file__) + +def setup_js(): + # setup js + js_dest_path = os.path.join(js_path, "inversion-demo-components") + if not os.path.exists(js_dest_path): + os.makedirs(js_dest_path) + js_src_path = os.path.join(inversion_demo_path, "js", "inversion-demo-components.js") + shutil.copy(js_src_path, js_dest_path) + +class ComponentInput: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "name": ("STRING", {"multiline": False}), + "data_type": ("STRING", {"multiline": False, "default": "IMAGE"}), + "extra_args": ("STRING", {"multiline": False}), + "explicit_input_order": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), + "optional": ([False, True],), + }, + "optional": { + "default_value": ("*",), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "component_input" + + CATEGORY = "InversionDemo Nodes/Component Creation" + + def component_input(self, name, data_type, extra_args, explicit_input_order, optional, default_value = None): + return (default_value,) + +class ComponentOutput: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), + "data_type": ("STRING", {"multiline": False, "default": "IMAGE"}), + "name": ("STRING", {"multiline": False}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "component_output" + + CATEGORY = "InversionDemo Nodes/Component Creation" + + def component_output(self, index, data_type, name, value): + return (value,) + +class ComponentMetadata: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "name": ("STRING", {"multiline": False}), + "always_output": ([False, True],), + }, + } + + RETURN_TYPES = () + FUNCTION = "nop" + + CATEGORY = "InversionDemo Nodes/Component Creation" + + def nop(self, name): + return {} + +COMPONENT_NODE_CLASS_MAPPINGS = { + "ComponentInput": ComponentInput, + "ComponentOutput": ComponentOutput, + "ComponentMetadata": ComponentMetadata, +} +COMPONENT_NODE_DISPLAY_NAME_MAPPINGS = { + "ComponentInput": "Component Input", + "ComponentOutput": "Component Output", + "ComponentMetadata": "Component Metadata", +} + +DEFAULT_EXTRA_DATA = { + "STRING": {"multiline": False}, + "INT": {"default": 0, "min": 0, "max": 1000, "step": 1}, + "FLOAT": {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}, +} + +def default_extra_data(data_type, extra_args): + if data_type == "STRING": + args = {"multiline": False} + elif data_type == "INT": + args = {"default": 0, "min": -1000000, "max": 1000000, "step": 1} + elif data_type == "FLOAT": + args = {"default": 0.0, "min": -1000000.0, "max": 1000000.0, "step": 0.1} + else: + args = {} + args.update(extra_args) + return args + +def LoadComponent(component_file): + try: + with open(component_file, "r") as f: + component_data = f.read() + graph = json.loads(component_data)["output"] + + component_raw_name = os.path.basename(component_file).split(".")[0] + component_display_name = component_raw_name + component_inputs = [] + component_outputs = [] + is_output_component = False + for node_id, data in graph.items(): + if data["class_type"] == "ComponentMetadata": + component_display_name = data["inputs"].get("name", component_raw_name) + is_output_component = data["inputs"].get("always_output", False) + elif data["class_type"] == "ComponentInput": + data_type = data["inputs"]["data_type"] + if len(data_type) > 0 and data_type[0] == "[": + try: + data_type = json.loads(data_type) + except: + pass + try: + extra_args = json.loads(data["inputs"]["extra_args"]) + except: + extra_args = {} + component_inputs.append({ + "node_id": node_id, + "name": data["inputs"]["name"], + "data_type": data_type, + "extra_args": extra_args, + "explicit_input_order": data["inputs"]["explicit_input_order"], + "optional": data["inputs"]["optional"], + }) + elif data["class_type"] == "ComponentOutput": + component_outputs.append({ + "node_id": node_id, + "name": data["inputs"]["name"] or data["inputs"]["data_type"], + "index": data["inputs"]["index"], + "data_type": data["inputs"]["data_type"], + }) + component_inputs.sort(key=lambda x: (x["explicit_input_order"], x["name"])) + component_outputs.sort(key=lambda x: x["index"]) + for i in range(1, len(component_inputs)): + if component_inputs[i]["name"] == component_inputs[i-1]["name"]: + raise Exception("Component input name is not unique: {}".format(component_inputs[i]["name"])) + for i in range(1, len(component_outputs)): + if component_outputs[i]["index"] == component_outputs[i-1]["index"]: + raise Exception("Component output index is not unique: {}".format(component_outputs[i]["index"])) + except Exception as e: + print("Error loading component file: {}: {}".format(component_file, e)) + return None + + class ComponentNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if not node["optional"]}, + "optional": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if node["optional"]}, + } + + RETURN_TYPES = tuple([node["data_type"] for node in component_outputs]) + RETURN_NAMES = tuple([node["name"] for node in component_outputs]) + FUNCTION = "expand_component" + + CATEGORY = "Custom Components" + OUTPUT_NODE = is_output_component + + def expand_component(self, **kwargs): + new_graph = copy.deepcopy(graph) + for input_node in component_inputs: + if input_node["name"] in kwargs: + new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]] + outputs = tuple([[node["node_id"], 0] for node in component_outputs]) + new_graph, outputs = comfy.graph_utils.add_graph_prefix(new_graph, outputs, comfy.graph_utils.GraphBuilder.alloc_prefix()) + return { + "result": outputs, + "expand": new_graph, + } + ComponentNode.__name__ = component_raw_name + COMPONENT_NODE_CLASS_MAPPINGS[component_raw_name] = ComponentNode + COMPONENT_NODE_DISPLAY_NAME_MAPPINGS[component_raw_name] = component_display_name + print("Loaded component: {}".format(component_display_name)) + +def load_components(): + component_dir = os.path.join(comfy_path, "components") + if not os.path.exists(component_dir): + return + files = [f for f in os.listdir(component_dir) if os.path.isfile(os.path.join(component_dir, f)) and f.endswith(".json")] + for f in files: + print("Loading component file %s" % f) + LoadComponent(os.path.join(component_dir, f)) + +load_components() diff --git a/custom_nodes/execution-inversion-demo-comfyui/conditions.py b/custom_nodes/execution-inversion-demo-comfyui/conditions.py new file mode 100644 index 00000000000..3dc1a75f3c8 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/conditions.py @@ -0,0 +1,194 @@ +import re +import torch + +class IntConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "int_condition" + + CATEGORY = "InversionDemo Nodes/Logic" + + def int_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + + +class FloatConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "float_condition" + + CATEGORY = "InversionDemo Nodes/Logic" + + def float_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + +class StringConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("STRING", {"multiline": False}), + "b": ("STRING", {"multiline": False}), + "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],), + "case_sensitive": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "string_condition" + + CATEGORY = "InversionDemo Nodes/Logic" + + def string_condition(self, a, b, operation, case_sensitive): + if not case_sensitive: + a = a.lower() + b = b.lower() + + if operation == "a == b": + return (a == b,) + elif operation == "a != b": + return (a != b,) + elif operation == "a IN b": + return (a in b,) + elif operation == "a MATCH REGEX(b)": + try: + return (re.match(b, a) is not None,) + except: + return (False,) + elif operation == "a BEGINSWITH b": + return (a.startswith(b),) + elif operation == "a ENDSWITH b": + return (a.endswith(b),) + +class ToBoolNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + }, + "optional": { + "invert": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "to_bool" + + CATEGORY = "InversionDemo Nodes/Logic" + + def to_bool(self, value, invert = False): + if isinstance(value, torch.Tensor): + if value.max().item() == 0 and value.min().item() == 0: + result = False + else: + result = True + else: + try: + result = bool(value) + except: + # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer. + result = True + + if invert: + result = not result + + return (result,) + +class BoolOperationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("BOOLEAN",), + "b": ("BOOLEAN",), + "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "bool_operation" + + CATEGORY = "InversionDemo Nodes/Logic" + + def bool_operation(self, a, b, op): + if op == "a AND b": + return (a and b,) + elif op == "a OR b": + return (a or b,) + elif op == "a XOR b": + return (a ^ b,) + elif op == "NOT a": + return (not a,) + + +CONDITION_NODE_CLASS_MAPPINGS = { + "IntConditions": IntConditions, + "FloatConditions": FloatConditions, + "StringConditions": StringConditions, + "ToBoolNode": ToBoolNode, + "BoolOperationNode": BoolOperationNode, +} + +CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { + "IntConditions": "Int Condition", + "FloatConditions": "Float Condition", + "StringConditions": "String Condition", + "ToBoolNode": "To Bool", + "BoolOperationNode": "Bool Operation", +} diff --git a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py new file mode 100644 index 00000000000..1b328d957b2 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py @@ -0,0 +1,165 @@ +from comfy.graph_utils import GraphBuilder, is_link +from comfy.graph import ExecutionBlocker + +NUM_FLOW_SOCKETS = 5 +class WhileLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("BOOLEAN", {"default": True}), + }, + "optional": { + }, + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["FLOW_CONTROL"] + ["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_open" + + CATEGORY = "InversionDemo Nodes/Flow" + + def while_loop_open(self, condition, **kwargs): + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(["stub"] + values) + +class WhileLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + "condition": ("BOOLEAN", {"forceInput": True}), + }, + "optional": { + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + } + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_close" + + CATEGORY = "InversionDemo Nodes/Flow" + + def explore_dependencies(self, node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if is_link(v): + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + self.explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + def collect_contained(self, node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + self.collect_contained(child_id, upstream, contained) + + + def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs): + if not condition: + # We're done with the loop + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(values) + + # We want to loop + this_node = dynprompt.get_node(unique_id) + upstream = {} + # Get the list of all nodes between the open and close nodes + self.explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = flow_control[0] + self.collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + # We'll use the default prefix, but to avoid having node names grow exponentially in size, + # we'll use "Recurse" for the name of the recursively-generated copy of this node. + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + for k, v in original_node["inputs"].items(): + if is_link(v) and v[0] in contained: + parent = graph.lookup_node(v[0]) + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + new_open = graph.lookup_node(open_node) + for i in range(NUM_FLOW_SOCKETS): + key = "initial_value%d" % i + new_open.set_input(key, kwargs.get(key, None)) + my_clone = graph.lookup_node("Recurse" ) + result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) + return { + "result": tuple(result), + "expand": graph.finalize(), + } + +class ExecutionBlockerNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "input": ("*",), + "block": ("BOOLEAN",), + "verbose": ("BOOLEAN", {"default": False}), + }, + } + return inputs + + RETURN_TYPES = ("*",) + RETURN_NAMES = ("output",) + FUNCTION = "execution_blocker" + + CATEGORY = "InversionDemo Nodes/Flow" + + def execution_blocker(self, input, block, verbose): + if block: + return (ExecutionBlocker("Blocked Execution" if verbose else None),) + return (input,) + +FLOW_CONTROL_NODE_CLASS_MAPPINGS = { + "WhileLoopOpen": WhileLoopOpen, + "WhileLoopClose": WhileLoopClose, + "ExecutionBlocker": ExecutionBlockerNode, +} +FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { + "WhileLoopOpen": "While Loop Open", + "WhileLoopClose": "While Loop Close", + "ExecutionBlocker": "Execution Blocker", +} diff --git a/custom_nodes/execution-inversion-demo-comfyui/js/inversion-demo-components.js b/custom_nodes/execution-inversion-demo-comfyui/js/inversion-demo-components.js new file mode 100644 index 00000000000..cf72e670c35 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/js/inversion-demo-components.js @@ -0,0 +1,64 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import {ComfyWidgets} from "../../scripts/widgets.js"; + +var update_comfyui_button = null; +var fetch_updates_button = null; + +const fileInput = $el("input", { + id: "component-file-input", + type: "file", + accept: ".json,image/png,.latent,.safetensors", + style: {display: "none"}, + parent: document.body, + onchange: async () => { + app.handleFile(fileInput.files[0]); + const reader = new FileReader(); + reader.onload = () => { + app.loadGraphData(JSON.parse(reader.result)["workflow"]); + }; + reader.readAsText(fileInput.files[0]); + }, +}); + +app.registerExtension({ + name: "Comfy.InversionDemoComponents", + + async setup() { + const menu = document.querySelector(".comfy-menu"); + const separator = document.createElement("hr"); + + separator.style.margin = "20px 0"; + separator.style.width = "100%"; + menu.append(separator); + + const saveButton = document.createElement("button"); + saveButton.textContent = "Save Component"; + saveButton.onclick = async () => { + let filename = "component.json"; + const p = await app.graphToPrompt(); + const json = JSON.stringify(p, null, 2); // convert the data to a JSON string + const blob = new Blob([json], {type: "application/json"}); + const url = URL.createObjectURL(blob); + const a = $el("a", { + href: url, + download: filename, + style: {display: "none"}, + parent: document.body, + }); + a.click(); + setTimeout(function () { + a.remove(); + window.URL.revokeObjectURL(url); + }, 0); + }; + + const loadButton = document.createElement("button"); + loadButton.textContent = "Load Component"; + loadButton.onclick = () => { + fileInput.click(); + }; + menu.append(saveButton); + menu.append(loadButton); + } +}); diff --git a/custom_nodes/execution-inversion-demo-comfyui/nodes.py b/custom_nodes/execution-inversion-demo-comfyui/nodes.py new file mode 100644 index 00000000000..63230d4688f --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/nodes.py @@ -0,0 +1,299 @@ +import re + +from comfy.graph_utils import GraphBuilder + +class InversionDemoAdvancedPromptNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "prompt": ("STRING", {"multiline": True}), + "model": ("MODEL",), + "clip": ("CLIP",), + }, + } + + RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING") + FUNCTION = "advanced_prompt" + + CATEGORY = "InversionDemo Nodes/Demo" + + def parse_timesteps(self, text): + text = re.sub(r'<[^<]*>', lambda m: m.group(0).replace(':', '||COLON||'), text) + text = re.sub(r':\d+\.\d+\)', lambda m: m.group(0).replace(':', '||COLON||'), text) + def recurse(text, min_value, max_value): + # First, replace any colons in angle brackets with a placeholder + pattern = r'\[([^:\[\]]*):([^:\[\]]*):(\d+\.\d+)\]' + m = re.search(pattern, text) + if m is None: + return [{ + "text": re.sub(r'\|\|COLON\|\|', ':', text), + "min": min_value, + "max": max_value, + }] + + # If we have a match, check if the value is in range + start, end = m.span() + value = float(m.group(3)) + before = text[:start] + m.group(1) + text[end:] + after = text[:start] + m.group(2) + text[end:] + if value <= min_value: + return recurse(after, min_value, max_value) + elif value >= max_value: + return recurse(before, min_value, max_value) + else: + return recurse(before, min_value, value) + recurse(after, value, max_value) + return recurse(text, 0, 1) + + def parse_loras(self, prompt): + # Get all string pieces matching the pattern "" + # where name is a string and strength is a float + # and clip_strength is an optional float + pattern = r"" + loras = re.findall(pattern, prompt) + if len(loras) == 0: + return prompt, loras + cleaned_prompt = re.sub(pattern, "", prompt).strip() + return cleaned_prompt, loras + + + def advanced_prompt(self, prompt, clip, model): + graph = GraphBuilder() + cleaned_prompt, loras = self.parse_loras(prompt) + for lora in loras: + lora_name = lora[0] + lora_model_strength = float(lora[1]) + lora_clip_strength = lora_model_strength if lora[2] == "" else float(lora[2]) + + loader = graph.node("LoraLoader", model=model, clip=clip, lora_name = lora_name, strength_model = lora_model_strength, strength_clip = lora_clip_strength) + model = loader.out(0) + clip = loader.out(1) + + timesteps = self.parse_timesteps(cleaned_prompt) + prev_output = None + for timestep in timesteps: + encoder = graph.node("CLIPTextEncode", clip=clip, text=timestep["text"]) + ranger = graph.node("ConditioningSetTimestepRange", conditioning=encoder.out(0), start=timestep["min"], end=timestep["max"]) + if prev_output is None: + prev_output = ranger.out(0) + else: + prev_output = graph.node("ConditioningCombine", conditioning_1=prev_output, conditioning_2=ranger.out(0)).out(0) + + return { + "result": (model, clip, prev_output), + "expand": graph.finalize(), + } + +class InversionDemoFakeAdvancedPromptNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "prompt": ("STRING", {"multiline": True}), + "clip": ("CLIP",), + "model": ("MODEL",), + }, + } + + RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING") + FUNCTION = "advanced_prompt" + + CATEGORY = "InversionDemo Nodes/Debug" + + def advanced_prompt(self, prompt, clip, model): + tokens = clip.tokenize(prompt) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return (model, clip, [[cond, {"pooled_output": pooled}]]) + +class InversionDemoLazySwitch: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "switch": ("BOOLEAN",), + "on_false": ("*", {"lazy": True}), + "on_true": ("*", {"lazy": True}), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "switch" + + CATEGORY = "InversionDemo Nodes/Logic" + + def check_lazy_status(self, switch, on_false = None, on_true = None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + def switch(self, switch, on_false = None, on_true = None): + value = on_true if switch else on_false + return (value,) + +NUM_IF_ELSE_NODES = 10 +class InversionDemoLazyConditional: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + args = { + "value1": ("*", {"lazy": True}), + "condition1": ("BOOLEAN", {"forceInput": True}), + } + + for i in range(1,NUM_IF_ELSE_NODES): + args["value%d" % (i + 1)] = ("*", {"lazy": True}) + args["condition%d" % (i + 1)] = ("BOOLEAN", {"lazy": True, "forceInput": True}) + + args["else"] = ("*", {"lazy": True}) + + return { + "required": {}, + "optional": args, + } + + RETURN_TYPES = ("*",) + FUNCTION = "conditional" + + CATEGORY = "InversionDemo Nodes/Logic" + + def check_lazy_status(self, **kwargs): + for i in range(0,NUM_IF_ELSE_NODES): + cond = "condition%d" % (i + 1) + if cond not in kwargs: + return [cond] + if kwargs[cond]: + val = "value%d" % (i + 1) + if val not in kwargs: + return [val] + else: + return [] + + if "else" not in kwargs: + return ["else"] + + def conditional(self, **kwargs): + for i in range(0,NUM_IF_ELSE_NODES): + cond = "condition%d" % (i + 1) + if cond not in kwargs: + return [cond] + if kwargs.get(cond, False): + val = "value%d" % (i + 1) + return (kwargs.get(val, None),) + + return (kwargs.get("else", None),) + + +class InversionDemoLazyIndexSwitch: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "index": ("INT", {"default": 0, "min": 0, "max": 9, "step": 1}), + "value0": ("*", {"lazy": True}), + }, + "optional": { + "value1": ("*", {"lazy": True}), + "value2": ("*", {"lazy": True}), + "value3": ("*", {"lazy": True}), + "value4": ("*", {"lazy": True}), + "value5": ("*", {"lazy": True}), + "value6": ("*", {"lazy": True}), + "value7": ("*", {"lazy": True}), + "value8": ("*", {"lazy": True}), + "value9": ("*", {"lazy": True}), + } + } + + RETURN_TYPES = ("*",) + FUNCTION = "index_switch" + + CATEGORY = "InversionDemo Nodes/Logic" + + def check_lazy_status(self, index, **kwargs): + key = "value%d" % index + if key not in kwargs: + return [key] + + def index_switch(self, index, **kwargs): + key = "value%d" % index + return (kwargs[key],) + +class InversionDemoLazyMixImages: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",{"lazy": True}), + "image2": ("IMAGE",{"lazy": True}), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mix" + + CATEGORY = "InversionDemo Nodes/Demo" + + def check_lazy_status(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + needed = [] + if image1 is None and (mask_min != 1.0 or mask_max != 1.0): + needed.append("image1") + if image2 is None and (mask_min != 0.0 or mask_max != 0.0): + needed.append("image2") + return needed + + # Not trying to handle different batch sizes here just to keep the demo simple + def mix(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + if mask_min == 0.0 and mask_max == 0.0: + return (image1,) + elif mask_min == 1.0 and mask_max == 1.0: + return (image2,) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if len(mask.shape) == 3: + mask = mask.unsqueeze(3) + if mask.shape[3] < image1.shape[3]: + mask = mask.repeat(1, 1, 1, image1.shape[3]) + + return (image1 * (1. - mask) + image2 * mask,) + +GENERAL_NODE_CLASS_MAPPINGS = { + "InversionDemoAdvancedPromptNode": InversionDemoAdvancedPromptNode, + "InversionDemoFakeAdvancedPromptNode": InversionDemoFakeAdvancedPromptNode, + "InversionDemoLazySwitch": InversionDemoLazySwitch, + "InversionDemoLazyIndexSwitch": InversionDemoLazyIndexSwitch, + "InversionDemoLazyMixImages": InversionDemoLazyMixImages, + "InversionDemoLazyConditional": InversionDemoLazyConditional, +} + +GENERAL_NODE_DISPLAY_NAME_MAPPINGS = { + "InversionDemoAdvancedPromptNode": "Advanced Prompt", + "InversionDemoFakeAdvancedPromptNode": "Fake Advanced Prompt", + "InversionDemoLazySwitch": "Lazy Switch", + "InversionDemoLazyIndexSwitch": "Lazy Index Switch", + "InversionDemoLazyMixImages": "Lazy Mix Images", + "InversionDemoLazyConditional": "Lazy Conditional", +} diff --git a/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py b/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py new file mode 100644 index 00000000000..598591551ff --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py @@ -0,0 +1,406 @@ +from comfy.graph_utils import GraphBuilder +import torch + +class AccumulateNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "to_add": ("*",), + }, + "optional": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + FUNCTION = "accumulate" + + CATEGORY = "InversionDemo Nodes/Lists" + + def accumulate(self, to_add, accumulation = None): + if accumulation is None: + value = [to_add] + else: + value = accumulation["accum"] + [to_add] + return ({"accum": value},) + +class AccumulationHeadNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_head" + + CATEGORY = "InversionDemo Nodes/Lists" + + def accumulation_head(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (accumulation, None) + else: + return ({"accum": accum[1:]}, accum[0]) + +class AccumulationTailNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_tail" + + CATEGORY = "InversionDemo Nodes/Lists" + + def accumulation_tail(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (None, accumulation) + else: + return ({"accum": accum[:-1]}, accum[-1]) + +class AccumulationToListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("*",) + OUTPUT_IS_LIST = (True,) + + FUNCTION = "accumulation_to_list" + + CATEGORY = "InversionDemo Nodes/Lists" + + def accumulation_to_list(self, accumulation): + return (accumulation["accum"],) + +class ListToAccumulationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + INPUT_IS_LIST = (True,) + + FUNCTION = "list_to_accumulation" + + CATEGORY = "InversionDemo Nodes/Lists" + + def list_to_accumulation(self, list): + return ({"accum": list},) + +class AccumulationGetLengthNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("INT",) + + FUNCTION = "accumlength" + + CATEGORY = "InversionDemo Nodes/Lists" + + def accumlength(self, accumulation): + return (len(accumulation['accum']),) + +class AccumulationGetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}) + }, + } + + RETURN_TYPES = ("*",) + + FUNCTION = "get_item" + + CATEGORY = "InversionDemo Nodes/Lists" + + def get_item(self, accumulation, index): + return (accumulation['accum'][index],) + +class AccumulationSetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + + FUNCTION = "set_item" + + CATEGORY = "InversionDemo Nodes/Lists" + + def set_item(self, accumulation, index, value): + new_accum = accumulation['accum'][:] + new_accum[index] = value + return ({"accum": new_accum},) + +class IntMathOperation: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "int_math_operation" + + CATEGORY = "InversionDemo Nodes/Logic" + + def int_math_operation(self, a, b, operation): + if operation == "add": + return (a + b,) + elif operation == "subtract": + return (a - b,) + elif operation == "multiply": + return (a * b,) + elif operation == "divide": + return (a // b,) + elif operation == "modulo": + return (a % b,) + elif operation == "power": + return (a ** b,) + + +from .flow_control import NUM_FLOW_SOCKETS +class ForLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}), + }, + "optional": { + "initial_value%d" % i: ("*",) for i in range(1, NUM_FLOW_SOCKETS) + }, + "hidden": { + "initial_value0": ("*",) + } + } + + RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["flow_control", "remaining"] + ["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_open" + + CATEGORY = "InversionDemo Nodes/Flow" + + def for_loop_open(self, remaining, **kwargs): + graph = GraphBuilder() + if "initial_value0" in kwargs: + remaining = kwargs["initial_value0"] + while_open = graph.node("WhileLoopOpen", condition=remaining, initial_value0=remaining, **{("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)}) + outputs = [kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)] + return { + "result": tuple(["stub", remaining] + outputs), + "expand": graph.finalize(), + } + +class ForLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + "old_remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1, "forceInput": True}), + }, + "optional": { + "initial_value%d" % i: ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) + }, + } + + RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_close" + + CATEGORY = "InversionDemo Nodes/Flow" + + def for_loop_close(self, flow_control, old_remaining, **kwargs): + graph = GraphBuilder() + while_open = flow_control[0] + # TODO - Requires WAS-ns. Will definitely want to solve before merging + sub = graph.node("IntMathOperation", operation="subtract", a=[while_open,1], b=1) + cond = graph.node("ToBoolNode", value=sub.out(0)) + input_values = {("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)} + while_close = graph.node("WhileLoopClose", + flow_control=flow_control, + condition=cond.out(0), + initial_value0=sub.out(0), + **input_values) + return { + "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]), + "expand": graph.finalize(), + } + +class DebugPrint: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + "label": ("STRING", {"multiline": False}), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "debug_print" + + CATEGORY = "InversionDemo Nodes/Debug" + + def debugtype(self, value): + if isinstance(value, list): + result = "[" + for i, v in enumerate(value): + result += (self.debugtype(v) + ",") + result += "]" + elif isinstance(value, tuple): + result = "(" + for i, v in enumerate(value): + result += (self.debugtype(v) + ",") + result += ")" + elif isinstance(value, dict): + result = "{" + for k, v in value.items(): + result += ("%s: %s," % (self.debugtype(k), self.debugtype(v))) + result += "}" + elif isinstance(value, str): + result = "'%s'" % value + elif isinstance(value, bool) or isinstance(value, int) or isinstance(value, float): + result = str(value) + elif isinstance(value, torch.Tensor): + result = "Tensor[%s]" % str(value.shape) + else: + result = type(value).__name__ + return result + + def debug_print(self, value, label): + print("[%s]: %s" % (label, self.debugtype(value))) + return (value,) + +NUM_LIST_SOCKETS = 10 +class MakeListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("*",), + }, + "optional": { + "value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS) + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "make_list" + OUTPUT_IS_LIST = (True,) + + CATEGORY = "InversionDemo Nodes/Lists" + + def make_list(self, **kwargs): + result = [] + for i in range(NUM_LIST_SOCKETS): + if "value%d" % i in kwargs: + result.append(kwargs["value%d" % i]) + return (result,) + +UTILITY_NODE_CLASS_MAPPINGS = { + "AccumulateNode": AccumulateNode, + "AccumulationHeadNode": AccumulationHeadNode, + "AccumulationTailNode": AccumulationTailNode, + "AccumulationToListNode": AccumulationToListNode, + "ListToAccumulationNode": ListToAccumulationNode, + "AccumulationGetLengthNode": AccumulationGetLengthNode, + "AccumulationGetItemNode": AccumulationGetItemNode, + "AccumulationSetItemNode": AccumulationSetItemNode, + "ForLoopOpen": ForLoopOpen, + "ForLoopClose": ForLoopClose, + "IntMathOperation": IntMathOperation, + "DebugPrint": DebugPrint, + "MakeListNode": MakeListNode, +} +UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { + "AccumulateNode": "Accumulate", + "AccumulationHeadNode": "Accumulation Head", + "AccumulationTailNode": "Accumulation Tail", + "AccumulationToListNode": "Accumulation to List", + "ListToAccumulationNode": "List to Accumulation", + "AccumulationGetLengthNode": "Accumulation Get Length", + "AccumulationGetItemNode": "Accumulation Get Item", + "AccumulationSetItemNode": "Accumulation Set Item", + "ForLoopOpen": "For Loop Open", + "ForLoopClose": "For Loop Close", + "IntMathOperation": "Int Math Operation", + "DebugPrint": "Debug Print", + "MakeListNode": "Make List", +} diff --git a/execution.py b/execution.py index 00908eadd46..e0a50bb4175 100644 --- a/execution.py +++ b/execution.py @@ -4,6 +4,7 @@ import threading import heapq import traceback +from enum import Enum import inspect from typing import List, Literal, NamedTuple, Optional @@ -11,29 +12,97 @@ import nodes import comfy.model_management +import comfy.graph_utils +from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy.graph_utils import is_link, GraphBuilder +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID + +class ExecutionResult(Enum): + SUCCESS = 0 + FAILURE = 1 + SLEEPING = 2 + +class IsChangedCache: + def __init__(self, dynprompt, outputs_cache): + self.dynprompt = dynprompt + self.outputs_cache = outputs_cache + self.is_changed = {} + + def get(self, node_id): + if node_id not in self.is_changed: + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, "IS_CHANGED"): + if "is_changed" in node: + self.is_changed[node_id] = node["is_changed"] + else: + input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + try: + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + self.is_changed[node_id] = node["is_changed"] + except: + node["is_changed"] = float("NaN") + self.is_changed[node_id] = node["is_changed"] + else: + self.is_changed[node_id] = False + return self.is_changed[node_id] -def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): +class CacheSet: + def __init__(self, lru_size=None): + if lru_size is None or lru_size == 0: + self.init_classic_cache() + else: + self.init_lru_cache(lru_size) + self.all = [self.outputs, self.ui, self.objects] + + # Useful for those with ample RAM/VRAM -- allows experimenting without + # blowing away the cache every time + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # Performs like the old cache -- dump data ASAP + def init_classic_cache(self): + self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID) + self.objects = HierarchicalCache(CacheKeySetID) + + def recursive_debug_dump(self): + result = { + "outputs": self.outputs.recursive_debug_dump(), + "ui": self.ui.recursive_debug_dump(), + } + return result + +def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): + input_type, input_category, input_info = get_input_info(class_def, x) + if is_link(input_data) and not input_info.get("rawLink", False): input_unique_id = input_data[0] output_index = input_data[1] - if input_unique_id not in outputs: - input_data_all[x] = (None,) + if outputs is None: + continue # This might be a lazily-evaluated input + cached_output = outputs.get(input_unique_id) + if cached_output is None: continue - obj = outputs[input_unique_id][output_index] + obj = cached_output[output_index] input_data_all[x] = obj - else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + elif input_category is not None: + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": input_data_all[x] = [prompt] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: input_data_all[x] = [extra_data['extra_pnginfo']] @@ -41,7 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [unique_id] return input_data_all -def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists input_is_list = False if hasattr(obj, "INPUT_IS_LIST"): @@ -63,51 +132,97 @@ def slice_dict(d, i): if input_is_list: if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**input_data_all)) + execution_block = None + for k, v in input_data_all.items(): + for input in v: + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb is not None else v + break + + if execution_block is None: + if pre_execute_cb is not None: + pre_execute_cb(0) + results.append(getattr(obj, func)(**input_data_all)) + else: + results.append(execution_block) elif max_len_input == 0: if allow_interrupt: nodes.before_node_execution() results.append(getattr(obj, func)()) - else: + else: for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + input_dict = slice_dict(input_data_all, i) + execution_block = None + for k, v in input_dict.items(): + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb is not None else v + break + if execution_block is None: + if pre_execute_cb is not None: + pre_execute_cb(i) + results.append(getattr(obj, func)(**input_dict)) + else: + results.append(execution_block) return results -def get_output_data(obj, input_data_all): +def merge_result_data(results, obj): + # check which outputs need concatenating + output = [] + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + return output + +def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): results = [] uis = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) - - for r in return_values: + subgraph_results = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + has_subgraph = False + for i in range(len(return_values)): + r = return_values[i] if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) + if 'expand' in r: + # Perform an expansion, but do not append results + has_subgraph = True + new_graph = r['expand'] + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif 'result' in r: + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: + if isinstance(r, ExecutionBlocker): + r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - + if has_subgraph: + output = subgraph_results + elif len(results) > 0: + output = merge_result_data(results, obj) + else: + output = [] ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + return output, ui, has_subgraph def format_value(x): if x is None: @@ -117,53 +232,144 @@ def format_value(x): else: return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): +def non_recursive_execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + real_node_id = dynprompt.get_real_node_id(unique_id) + display_node_id = dynprompt.get_display_node_id(unique_id) + parent_node_id = dynprompt.get_parent_node_id(unique_id) + inputs = dynprompt.get_node(unique_id)['inputs'] + class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if unique_id in outputs: - return (True, None, None) - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) - if result[0] is not True: - # Another node failed further upstream - return result + if caches.outputs.get(unique_id) is not None: + if server.client_id is not None: + cached_output = caches.ui.get(unique_id) or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - - obj = object_storage.get((unique_id, class_type), None) - if obj is None: - obj = class_def() - object_storage[(unique_id, class_type)] = obj + if unique_id in pending_subgraph_results: + cached_results = pending_subgraph_results[unique_id] + resolved_outputs = [] + for is_subgraph, result in cached_results: + if not is_subgraph: + resolved_outputs.append(result) + else: + resolved_output = [] + for r in result: + if is_link(r): + source_node, source_output = r[0], r[1] + node_output = caches.outputs.get(source_node)[source_output] + for o in node_output: + resolved_output.append(o) - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data + else: + resolved_output.append(r) + resolved_outputs.append(tuple(resolved_output)) + output_data = merge_result_data(resolved_outputs, class_def) + output_ui = [] + has_subgraph = False + else: + input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data) + if server.client_id is not None: + server.last_node_id = display_node_id + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + + obj = caches.objects.get(unique_id) + if obj is None: + obj = class_def() + caches.objects.set(unique_id, obj) + + if hasattr(obj, "check_lazy_status"): + required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) + required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all] + if len(required_inputs) > 0: + for i in required_inputs: + execution_list.make_input_strong_link(unique_id, i) + return (ExecutionResult.SLEEPING, None, None) + + def execution_block_cb(block): + if block.message is not None: + mes = { + "prompt_id": prompt_id, + "node_id": unique_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": "Execution Blocked: %s" % block.message, + "exception_type": "ExecutionBlocked", + "traceback": [], + "current_inputs": [], + "current_outputs": [], + } + server.send_sync("execution_error", mes, server.client_id) + return ExecutionBlocker(None) + else: + return block + def pre_execute_cb(call_index): + GraphBuilder.set_default_prefix(unique_id, call_index, 0) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + caches.ui.set(unique_id, { + "meta": { + "node_id": unique_id, + "display_node": display_node_id, + "parent_node": parent_node_id, + "real_node_id": real_node_id, + }, + "output": output_ui + }) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if has_subgraph: + cached_outputs = [] + new_node_ids = [] + new_output_ids = [] + new_output_links = [] + for i in range(len(output_data)): + new_graph, node_outputs = output_data[i] + if new_graph is None: + cached_outputs.append((False, node_outputs)) + else: + # Check for conflicts + for node_id in new_graph.keys(): + if dynprompt.get_node(node_id) is not None: + raise Exception("Attempt to add duplicate node %s" % node_id) + break + for node_id, node_info in new_graph.items(): + new_node_ids.append(node_id) + display_id = node_info.get("override_display_id", unique_id) + dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) + # Figure out if the newly created node is an output node + class_type = node_info["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + new_output_ids.append(node_id) + for i in range(len(node_outputs)): + if is_link(node_outputs[i]): + from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + new_output_links.append((from_node_id, from_socket)) + cached_outputs.append((True, node_outputs)) + new_node_ids = set(new_node_ids) + for cache in caches.all: + cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + for node_id in new_output_ids: + execution_list.add_node(node_id) + for link in new_output_links: + execution_list.add_strong_link(link[0], link[1], unique_id) + pending_subgraph_results[unique_id] = cached_outputs + return (ExecutionResult.SLEEPING, None, None) + caches.outputs.set(unique_id, output_data) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") # skip formatting inputs/outputs error_details = { - "node_id": unique_id, + "node_id": real_node_id, } - return (False, error_details, iex) + return (ExecutionResult.FAILURE, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) @@ -174,108 +380,43 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_data_formatted[name] = [format_value(x) for x in inputs] output_data_formatted = {} - for node_id, node_outputs in outputs.items(): - output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + # TODO - Implement me + # for node_id, node_outputs in outputs.items(): + # output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] logging.error("!!! Exception during processing !!!") logging.error(traceback.format_exc()) error_details = { - "node_id": unique_id, + "node_id": real_node_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, "current_outputs": output_data_formatted } - return (False, error_details, ex) + return (ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) - return (True, None, None) - -def recursive_will_execute(prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) - - return will_execute + [unique_id] - -def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - is_changed_old = '' - is_changed = '' - to_delete = False - if hasattr(class_def, 'IS_CHANGED'): - if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: - is_changed_old = old_prompt[unique_id]['is_changed'] - if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: - try: - #is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - prompt[unique_id]['is_changed'] = is_changed - except: - to_delete = True - else: - is_changed = prompt[unique_id]['is_changed'] - - if unique_id not in outputs: - return True - - if not to_delete: - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True + return (ExecutionResult.SUCCESS, None, None) - if to_delete: - d = outputs.pop(unique_id) - del d - return to_delete +CACHE_FOR_DEBUG_DUMP = None +def dump_cache_for_debug(): + return CACHE_FOR_DEBUG_DUMP.recursive_debug_dump() class PromptExecutor: - def __init__(self, server): + def __init__(self, server, lru_size=None): + self.lru_size = lru_size self.server = server self.reset() def reset(self): - self.outputs = {} - self.object_storage = {} - self.outputs_ui = {} + self.caches = CacheSet(self.lru_size) + global CACHE_FOR_DEBUG_DUMP + CACHE_FOR_DEBUG_DUMP = self.caches self.status_messages = [] self.success = True - self.old_prompt = {} def add_message(self, event, data, broadcast: bool): self.status_messages.append((event, data)) @@ -302,7 +443,6 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e "node_id": node_id, "node_type": class_type, "executed": list(executed), - "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], @@ -311,18 +451,6 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e } self.add_message("execution_error", mes, broadcast=False) - # Next, remove the subsequent outputs since they will not be executed - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -335,61 +463,45 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): - #delete cached outputs if nodes don't exist for them - to_delete = [] - for o in self.outputs: - if o not in prompt: - to_delete += [o] - for o in to_delete: - d = self.outputs.pop(o) - del d - to_delete = [] - for o in self.object_storage: - if o[0] not in prompt: - to_delete += [o] - else: - p = prompt[o[0]] - if o[1] != p['class_type']: - to_delete += [o] - for o in to_delete: - d = self.object_storage.pop(o) - del d - - for x in prompt: - recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) - - current_outputs = set(self.outputs.keys()) - for x in list(self.outputs_ui.keys()): - if x not in current_outputs: - d = self.outputs_ui.pop(x) - del d + dynamic_prompt = DynamicPrompt(prompt) + is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() + + current_outputs = self.caches.outputs.all_node_ids() comfy.model_management.cleanup_models() self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) + pending_subgraph_results = {} executed = set() - output_node_id = None - to_execute = [] - + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) for node_id in list(execute_outputs): - to_execute += [(0, node_id)] - - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - output_node_id = to_execute.pop(0)[-1] - - # This call shouldn't raise anything if there's an error deep in - # the actual SD code, instead it will report the node where the - # error was raised - self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if self.success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) - break + execution_list.add_node(node_id) - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) + while not execution_list.is_empty(): + node_id = execution_list.stage_node_execution() + result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.SLEEPING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + + ui_outputs = {} + meta_outputs = {} + for ui_info in self.caches.ui.all_active_values(): + node_id = ui_info["meta"]["node_id"] + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() @@ -406,7 +518,7 @@ def validate_inputs(prompt, item, validated): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] valid = True @@ -415,22 +527,23 @@ def validate_inputs(prompt, item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args - for x in required_inputs: + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info = required_inputs[x] - type_input = info[0] + info = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -449,7 +562,7 @@ def validate_inputs(prompt, item, validated): o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: + if type_input != "*" and r[val[1]] != "*" and r[val[1]] != type_input: received_type = r[val[1]] details = f"{x}, {received_type} != {type_input}" error = { @@ -501,6 +614,9 @@ def validate_inputs(prompt, item, validated): if type_input == "STRING": val = str(val) inputs[x] = val + if type_input == "BOOLEAN": + val = bool(val) + inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", @@ -516,33 +632,32 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue + } + errors.append(error) + continue if x not in validate_function_inputs: if isinstance(type_input, list): @@ -582,7 +697,7 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") for x in input_filtered: for i, r in enumerate(ret): - if r is not True: + if r is not True and not isinstance(r, ExecutionBlocker): details = f"{x}" if r is not False: details += f" - {str(r)}" @@ -741,7 +856,7 @@ class ExecutionStatus(NamedTuple): completed: bool messages: List[str] - def task_done(self, item_id, outputs, + def task_done(self, item_id, history_result, status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) @@ -754,9 +869,10 @@ def task_done(self, item_id, outputs, self.history[prompt[1]] = { "prompt": prompt, - "outputs": copy.deepcopy(outputs), + "outputs": {}, 'status': status_dict, } + self.history[prompt[1]].update(history_result) self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 69d9bce6cb7..8cd869e4885 100644 --- a/main.py +++ b/main.py @@ -91,7 +91,7 @@ def cuda_malloc_warning(): print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def prompt_worker(q, server): - e = execution.PromptExecutor(server) + e = execution.PromptExecutor(server, lru_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -111,7 +111,7 @@ def prompt_worker(q, server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, - e.outputs_ui, + e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, diff --git a/server.py b/server.py index dca06f6fc32..c935505879c 100644 --- a/server.py +++ b/server.py @@ -396,6 +396,7 @@ def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] info = {} info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] @@ -451,6 +452,22 @@ async def get_queue(request): queue_info['queue_pending'] = current_queue[1] return web.json_response(queue_info) + @routes.get("/debugcache") + async def get_debugcache(request): + def custom_serialize(obj): + from comfy.caching import Unhashable + if isinstance(obj, frozenset): + try: + return dict(obj) + except: + return list(obj) + elif isinstance(obj, Unhashable): + return "NaN" + return str(obj) + def custom_dump(obj): + return json.dumps(obj, default=custom_serialize) + return web.json_response(execution.dump_cache_for_debug(), dumps=custom_dump) + @routes.post("/prompt") async def post_prompt(request): print("got prompt") @@ -632,6 +649,9 @@ async def start(self, address, port, verbose=True, call_on_start=None): site = web.TCPSite(runner, address, port) await site.start() + self.address = address + self.port = port + if verbose: print("Starting server\n") print("To see the GUI go to: http://{}:{}".format(address, port)) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index e6ebedd9150..15b784d6768 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -443,6 +443,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${nodes.save.id}`, + display_node: `${nodes.save.id}`, output: { images: [ { @@ -483,6 +484,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${group.id}:5`, + display_node: `${group.id}:5`, output: { images: [ { diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 0f041fcd2f9..b78d33aac7c 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -956,8 +956,8 @@ export class GroupNodeHandler { const executed = handleEvent.call( this, "executed", - (d) => d?.node, - (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution }) + (d) => d?.display_node, + (d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution }) ); const onRemoved = node.onRemoved; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 3f1c1f8c126..f89c731e6bb 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js"; import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; -const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); const TARGET = Symbol(); // Used for reroutes to specify the real target widget diff --git a/web/scripts/api.js b/web/scripts/api.js index 3a9bcc87a4e..ae3fbd13a01 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -126,7 +126,7 @@ class ComfyApi extends EventTarget { this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node })); break; case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 6df393ba60d..d1687845438 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1255,7 +1255,7 @@ export class ComfyApp { }); api.addEventListener("executed", ({ detail }) => { - const output = this.nodeOutputs[detail.node]; + const output = this.nodeOutputs[detail.display_node]; if (detail.merge && output) { for (const k in detail.output ?? {}) { const v = output[k]; @@ -1266,9 +1266,9 @@ export class ComfyApp { } } } else { - this.nodeOutputs[detail.node] = detail.output; + this.nodeOutputs[detail.display_node] = detail.output; } - const node = this.graph.getNodeById(detail.node); + const node = this.graph.getNodeById(detail.display_node); if (node) { if (node.onExecuted) node.onExecuted(detail.output); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index d4835c6e445..d69434993b0 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -227,7 +227,14 @@ class ComfyList { onclick: async () => { await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); if (item.outputs) { - app.nodeOutputs = item.outputs; + app.nodeOutputs = {}; + for (const [key, value] of Object.entries(item.outputs)) { + if (item.meta && item.meta[key] && item.meta[key].display_node) { + app.nodeOutputs[item.meta[key].display_node] = value; + } else { + app.nodeOutputs[key] = value; + } + } } }, }),