-# ©2021 Dustin Walde
+# ©2021-2022 Dustin Walde
# Licensed under GPL-3.0
import sys
INPUT_SYMBOL = '+'
OUTPUT_SYMBOL = '-'
SOCKET_DELIMETER = ':'
-PATH_DELIMETER = '/'
def is_group(node: Node) -> bool:
"""
- Check if node is a *GroupNode of some sort
+ Check if node is a *GroupNode of some sort.
"""
return _is_union_instance(node, NodeGroupType)
-def build_path(*nodes: str,
+def format_node(node: str,
input: Union[bool,str]=False,
output: Union[bool,str]=False) \
-> str:
if input != False and output != False:
- raise ValueError("Path cannot be both input and output")
-
- node_path = PATH_DELIMETER.join(nodes)
-
- def _format_socket_path(socket, symbol):
- if isinstance(socket, str):
- return '{}{}{}{}'.format(
- symbol, node_path, SOCKET_DELIMETER, socket)
-
- return '{}{}'.format(symbol, node_path)
+ raise ValueError("Node ref cannot be both input and output")
if input != False:
- return _format_socket_path(input, INPUT_SYMBOL)
+ return _format_socket_str(node, input, INPUT_SYMBOL)
elif output != False:
- return _format_socket_path(output, OUTPUT_SYMBOL)
+ return _format_socket_str(node, output, OUTPUT_SYMBOL)
- return node_path
+ return node
def link_nodes(
tree: NodeTree,
"""
Link from output node to input node.
If no sockets are specified, this will match as many as it can find.
- Will not create new links to input sockets if preserve_existing is
- set.
+ If multiple inputs or outputs are set, matches will only be set for
+ matching types.
+ Adding a connection that changes the type requires connecting
+ specific sockets (one out to one in).
+ Will not create new links from output or to input sockets if
+ `preserve_existing` is set.
"""
connections = 0
output_sockets = []
input_sockets = []
- parent_tree = tree
-
+ # get output sockets
if isinstance(link_from, NodeSocket):
if link_from.is_output:
output_sockets.append(link_from)
for output in link_from.outputs:
output_sockets.append(output)
else:
- parent_tree, _, _, output_sockets = resolve_path(tree, link_from)
+ _, _, output_sockets = resolve_node(tree, link_from)
+
+ output_sockets = _label_sockets(output_sockets, preserve_existing)
+ # get input sockets
if isinstance(link_to, NodeSocket):
if not link_to.is_output:
input_sockets.append(link_to)
for input in link_to.inputs:
input_sockets.append(input)
else:
- parent_tree, _, input_sockets, _ = resolve_path(tree, link_to)
+ _, input_sockets, _ = resolve_node(tree, link_to)
- for i in range(len(input_sockets)):
- input_sockets[i] = [
- input_sockets[i],
- preserve_existing and len(input_sockets[i].links) > 0]
+ input_sockets = _label_sockets(input_sockets, preserve_existing)
- for output in output_sockets:
+ for output_data in output_sockets:
+ output, out_matched = output_data
+ if out_matched:
+ continue
for input_data in input_sockets:
input, matched = input_data
if matched:
if output.type == input.type \
or (len(output_sockets) == 1 and len(input_sockets) == 1):
- parent_tree.links.new(output, input)
+ tree.links.new(output, input)
input_data[1] = True
connections += 1
break
for output in of.outputs:
links.extend(output.links)
else:
- _, _, inputs, outputs = resolve_path(tree, of)
+ _, inputs, outputs = resolve_node(tree, of)
for input in inputs:
links.extend(input.links)
for output in outputs:
def remove_links(tree: NodeTree, node: NodeType) -> int:
removed = 0
- if isinstance(node, str):
- node_tree, node, inputs, outputs = resolve_path(tree, node)
- for input in inputs:
- for link in input.links:
- node_tree.links.remove(link)
- removed += 1
- for output in outputs:
- for link in output.links:
- node_tree.links.remove(link)
- removed += 1
- else:
- for link in get_links(tree, node):
- tree.links.remove(link)
- removed += 1
+ for link in get_links(tree, node):
+ tree.links.remove(link)
+ removed += 1
return removed
-def get_socket(tree: NodeTree, node_path: str) -> NodeSocket:
- return get_sockets(tree, node_path)[0]
+def get_socket(tree: NodeTree, node_str: str) -> NodeSocket:
+ return get_sockets(tree, node_str)[0]
-def get_sockets(tree: NodeTree, node_path: str) -> List[NodeSocket]:
- _, _, inputs, outputs = resolve_path(tree, node_path)
+def get_sockets(tree: NodeTree, node_str: str) -> List[NodeSocket]:
+ _, inputs, outputs = resolve_node(tree, node_str)
inputs.extend(outputs)
return inputs
-def get_tree(tree: NodeTree, node_path: str) -> NodeTree:
- return resolve_path(tree, node_path)[0]
+def get_input_sockets(tree: NodeTree, node_str: str) -> List[NodeSocket]:
+ return resolve_node(tree, node_str)[1]
-def get_node(tree: NodeTree, node_path: str) -> Node:
- """
- Retrieve the the Node from tree at the given node_path.
+def get_output_sockets(tree: NodeTree, node_str: str) -> List[NodeSocket]:
+ return resolve_node(tree, node_str)[2]
- raises ValueError if node_path is invalid.
+def get_node(tree: NodeTree, node_str: str) -> Node:
"""
- return resolve_path(tree, node_path)[1]
-
-def resolve_path(tree: NodeTree, node_path: str) \
- -> Tuple[NodeTree, Node, List[NodeSocket], List[NodeSocket]]:
+ Retrieve the the Node from tree at the given `node_str`.
- path_parts, socket_name, socket_index, include_inputs, include_outputs = \
- _parse_path(node_path)
-
- current_tree = tree
- for node in path_parts[:-1]:
- if node not in current_tree.nodes:
- _invalid_path(node_path,
- "Node name ({}) not found".format(node))
-
- current_tree = current_tree.nodes[node]
+ raises ValueError if `node_path` is invalid.
+ """
+ return resolve_node(tree, node_str)[0]
- if not is_group(current_tree):
- _invalid_path(node_path,
- "Parent node ({}) is not a group.".format(node))
+def resolve_node(tree: NodeTree, node_str: str) \
+ -> Tuple[Node, List[NodeSocket], List[NodeSocket]]:
+ """
+ In the context of `tree`, get the Node and input and/or output sockets
+ of the `node_str`.
- current_tree = current_tree.node_tree
+ Return:
+ Tuple containing (Node, `Socket`s, output `Socket`s)
+ """
- leaf_node = path_parts[-1]
+ node_name, socket_name, socket_index, include_inputs, include_outputs = \
+ _parse_node_str(node_str)
- if leaf_node not in current_tree.nodes:
- _invalid_path(node_path,
- "Node name ({}) not found".format(leaf_node))
+ if node_name not in tree.nodes:
+ _invalid_node_str(node_str,
+ "Node name ({}) not found".format(node_name))
- node = current_tree.nodes[leaf_node]
+ node = tree.nodes[node_name]
inputs = []
outputs = []
- def _append_matching_sockets(
- include: bool,
- node_sockets: Union[NodeInputs,NodeOutputs],
- out_sockets: List[NodeSocket]):
- name_count = 0
- if include:
- for socket in node_sockets:
- if socket_name is None or socket_name == socket.name:
- if socket_index < 0 or socket_index == name_count:
- out_sockets.append(socket)
- name_count += 1
+ _append_matching_sockets(socket_name, socket_index, include_inputs, node.inputs, inputs)
+ _append_matching_sockets(socket_name, socket_index, include_outputs, node.outputs, outputs)
- _append_matching_sockets(include_inputs, node.inputs, inputs)
- _append_matching_sockets(include_outputs, node.outputs, outputs)
-
- return (current_tree, node, inputs, outputs)
+ return (node, inputs, outputs)
# helpers ↓
-def _invalid_path(node_path, message):
- raise ValueError("Invalid Path: {}\n{}".format(node_path, message))
+def _append_matching_sockets(
+ socket_name: str,
+ socket_index: str,
+ include: bool,
+ node_sockets: Union[NodeInputs,NodeOutputs],
+ out_sockets: List[NodeSocket]):
+ name_count = 0
+ if include:
+ for socket in node_sockets:
+ if socket_name is None or socket_name == socket.name:
+ if socket_index < 0 or socket_index == name_count:
+ out_sockets.append(socket)
+ name_count += 1
+
+def _format_socket_str(node, socket, symbol):
+ if isinstance(socket, str):
+ return '{}{}{}{}'.format(
+ symbol, node, SOCKET_DELIMETER, socket)
+
+ return '{}{}'.format(symbol, node)
+
+def _invalid_node_str(node_str, message):
+ raise ValueError("Invalid Node Str: {}\n{}".format(node_str, message))
def _is_union_instance(item, union_type) -> bool:
+ """
+ Check if `item` is of any of the types contained in a Union
+ definition `union_type`.
+ """
if not get_origin(union_type) is Union:
raise TypeError("Expected Union type")
return isinstance(item, get_args(union_type))
-def _parse_path(node_path: str) -> Tuple[List[str], str, int, bool, bool]:
+def _label_sockets(sockets: List[NodeSocket], preserve_existing: bool)\
+ -> List[Tuple[NodeSocket,bool]]:
+ out: List[Tuple[NodeSocket,bool]] = []
+ for i in range(len(sockets)):
+ out.append((
+ sockets[i],
+ preserve_existing and len(sockets[i].links) > 0))
+ return out
+
+def _parse_node_str(node_str: str) -> Tuple[str, str, int, bool, bool]:
input = True
output = True
socket = None
socket_index = -1
- path_parts = None
+ node = None
- if node_path[0] == INPUT_SYMBOL:
+ if node_str[0] == INPUT_SYMBOL:
output = False
- node_path = node_path[1:]
- elif node_path[0] == OUTPUT_SYMBOL:
+ node_str = node_str[1:]
+ elif node_str[0] == OUTPUT_SYMBOL:
input = False
- node_path = node_path[1:]
+ node_str = node_str[1:]
- parts = node_path.split(SOCKET_DELIMETER)
+ parts = node_str.split(SOCKET_DELIMETER)
if len(parts) > 1:
socket = parts[-1]
- path_parts = SOCKET_DELIMETER.join(parts[:-1])\
- .split(PATH_DELIMETER)
+ node = SOCKET_DELIMETER.join(parts[:-1])
else:
- path_parts = node_path.split(PATH_DELIMETER)
+ node = node_str
if socket is not None:
left_bracket = socket.find('[')
socket_index = int(socket[left_bracket+1:right_bracket])
socket = socket[:left_bracket]
- return (path_parts, socket, socket_index, input, output)
\ No newline at end of file
+ return (node, socket, socket_index, input, output)
\ No newline at end of file