Skip to content

Commit

Permalink
Reimplement something much closer to TreeReg as an experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Feb 7, 2025
1 parent 4872e30 commit 1aaed0c
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 37 deletions.
17 changes: 12 additions & 5 deletions stanza/models/constituency/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from stanza.models.common import utils
from stanza.models.constituency import transition_sequence
from stanza.models.constituency.parse_transitions import TransitionScheme, CloseConstituent
from stanza.models.constituency.parse_transitions import TransitionScheme, CloseConstituent, Shift
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.state import State
from stanza.models.constituency.tree_stack import TreeStack
Expand Down Expand Up @@ -274,7 +274,7 @@ def build_batch_from_tagged_words(self, batch_size, data_iterator):
return state_batch


def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False, keep_words=False):
"""
Repeat transitions to build a list of trees from the input batches.
Expand All @@ -299,7 +299,7 @@ def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_
batch_indices = list(range(len(state_batch)))
horizon_iterator = iter([])

if keep_constituents:
if keep_constituents or keep_words:
constituents = defaultdict(list)

while len(state_batch) > 0:
Expand All @@ -315,6 +315,13 @@ def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_
# constituents.value is the TreeStack node
# constituents.value.value is the Constituent itself (with the tree and the embedding)
constituents[batch_indices[t_idx]].append(state_batch[t_idx].constituents.value.value)
if keep_words:
for t_idx, transition in enumerate(transitions):
if isinstance(transition, Shift):
# constituents is a TreeStack with information on how to build the next state of the LSTM or attn
# constituents.value is the TreeStack node
# constituents.value.value is the Constituent itself (with the tree and the embedding)
constituents[batch_indices[t_idx]].append(state_batch[t_idx].constituents.value.value)

remove = set()
for idx, state in enumerate(state_batch):
Expand Down Expand Up @@ -363,7 +370,7 @@ def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, tra
with torch.no_grad():
return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)

def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituents=True, keep_scores=True):
def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituents=True, keep_scores=True, keep_words=False):
"""
Return a ParseResult for each tree in the trees list
Expand All @@ -377,7 +384,7 @@ def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituen
# TODO: refactor?
batch_size = self.args['eval_batch_size']
tree_iterator = iter(trees)
treebank = self.parse_sentences(tree_iterator, self.build_batch_from_trees_with_gold_sequence, batch_size, self.predict_gold, keep_state, keep_constituents, keep_scores=keep_scores)
treebank = self.parse_sentences(tree_iterator, self.build_batch_from_trees_with_gold_sequence, batch_size, self.predict_gold, keep_state, keep_constituents, keep_scores=keep_scores, keep_words=keep_words)
return treebank

def parse_tagged_words(self, words, batch_size):
Expand Down
37 changes: 36 additions & 1 deletion stanza/models/constituency/parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ def count_wide_neighbors(self):

def move_first_wide_neighbor(self):
already_moved = False
subtrees = []

def move_helper(tree):
if tree.is_leaf():
Expand All @@ -633,6 +634,11 @@ def move_helper(tree):
new_children.append(Tree(child.label, left_children))
new_children.append(Tree(next_child.label, right_children))

subtrees.append(child)
subtrees.append(next_child)
subtrees.append(new_children[-2])
subtrees.append(new_children[-1])

if child_idx + 2 < len(tree.children):
new_children.extend(move_helper(x) for x in tree.children[child_idx + 2:])

Expand All @@ -646,11 +652,40 @@ def move_helper(tree):
new_children.append(Tree(child.label, left_children))
new_children.append(Tree(next_child.label, right_children))

subtrees.append(child)
subtrees.append(next_child)
subtrees.append(new_children[-2])
subtrees.append(new_children[-1])

if child_idx + 2 < len(tree.children):
new_children.extend(move_helper(x) for x in tree.children[child_idx + 2:])

return Tree(tree.label, new_children)
new_children.append(move_helper(child))
return Tree(tree.label, new_children)

return move_helper(self)
return move_helper(self), subtrees

def mark_parents(self, parent=None):
    """
    Annotate this subtree with upward links and sibling positions

    Sets `.parent` on this node to the given parent (None by default),
    then recursively does the same for each child.  Each child also
    gets `.child_index`, its position in this node's children.

    The node the call starts from does not get a `child_index` here.
    """
    self.parent = parent

    for position, subtree in enumerate(self.children):
        # link everything underneath the child first,
        # then record where the child sits among its siblings
        subtree.mark_parents(parent=self)
        subtree.child_index = position

def find_previous_span(self):
    """
    Return the subtree immediately to the left of this one, or None

    Reads the `.parent` and `.child_index` attributes, which
    mark_parents sets.  If this node is not the first child of its
    parent, the answer is the sibling just before it.  If it is the
    first child, the question is passed up to the parent.  A node
    with no parent has no previous span, so None is returned.
    """
    container = self.parent
    if container is None:
        return None

    position = self.child_index
    if position != 0:
        return container.children[position - 1]
    # first child: whatever precedes the parent also precedes this node
    return container.find_previous_span()

def find_next_span(self):
    """
    Return the subtree immediately to the right of this one, or None

    Reads the `.parent` and `.child_index` attributes, which
    mark_parents sets.  If this node is not the last child of its
    parent, the answer is the sibling just after it.  If it is the
    last child, the question is passed up to the parent.  A node
    with no parent has no next span, so None is returned.
    """
    container = self.parent
    if container is None:
        return None

    position = self.child_index
    siblings = container.children
    if position != len(siblings) - 1:
        return siblings[position + 1]
    # last child: whatever follows the parent also follows this node
    return container.find_next_span()

80 changes: 51 additions & 29 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,35 +685,57 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
orthogonal_loss = 0.0
wide_neighbors = 0
if epoch >= args['orthogonal_initial_epoch'] and orthogonal_loss_function is not None:
wide_neighbors = sum(x.tree.count_wide_neighbors() for x in training_batch)

gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)
orthogonal_losses = []
def build_losses(con_values, tree):
    """
    Collect dot products between the encodings of sibling phrases

    Recurses through tree.  At every node, each pair of non-preterminal
    children whose str() is a key of con_values adds one term to the
    enclosing orthogonal_losses list: the dot product of the two
    stored values (after squeeze(0)), divided by the number of
    non-preterminal children minus one.
    """
    # preterminals are left out of the comparison
    # a preterminal in the middle of a phrase has a high chance of being
    # a conjunction, a punctuation, or other non-function word anyway
    phrases = [child for child in tree.children if not child.is_preterminal()]
    for phrase in phrases:
        build_losses(con_values, phrase)

    # only used when there are at least two phrases, so never a division by 0
    divisor = len(phrases) - 1
    for first, second in itertools.combinations(phrases, 2):
        first_key = str(first)
        second_key = str(second)
        if first_key not in con_values or second_key not in con_values:
            continue
        first_value = con_values[first_key].squeeze(0)
        second_value = con_values[second_key].squeeze(0)
        orthogonal_losses.append(torch.dot(first_value, second_value) / divisor)
for result in gold_results:
gold_constituents = result.constituents
con_values = {}
for con in gold_constituents:
# normalize so that we are enforcing only the angle go to 0, not the length
con_values[str(con.value)] = nn.functional.normalize(con.tree_hx)
build_losses(con_values, result.gold)
orthogonal_losses = torch.stack(orthogonal_losses)
orthogonal_target = torch.zeros(orthogonal_losses.shape).to(orthogonal_losses.device)
orthogonal_loss = orthogonal_loss_function(orthogonal_losses, orthogonal_target) * args['orthogonal_learning_rate']
all_trees = [(x.tree.move_first_wide_neighbor(), x.tree) for x in training_batch if x.tree.count_wide_neighbors() > 0]
mutated_trees = [x[0] for x in all_trees]
gold_trees = [x[1] for x in all_trees]
wide_neighbors = len(mutated_trees)
if wide_neighbors > 0:
mutated_results = model.analyze_trees([x[0] for x in mutated_trees], keep_constituents=True, keep_scores=False, keep_words=True)
gold_results = model.analyze_trees(gold_trees, keep_constituents=True, keep_scores=False, keep_words=True)

orthogonal_losses = []

def orth(x, y):
    """
    Length of what is left of x after removing its projection onto y

    Returns 0 if either argument is None, which is how a missing
    neighbor span is represented by the caller.

    The projection is computed as dot(x, y) * y, so this is only the
    true orthogonal component when y has unit length.
    NOTE(review): the visible callers pass normalized vectors - confirm
    before using this on anything else.
    """
    if x is None or y is None:
        return 0
    projection = torch.dot(x, y) * y
    return torch.linalg.norm(x - projection)

def span_independence(tree, subtree, tree_hx):
    """
    SCIN from the TreeReg paper

    tree_hx maps str(span) to the encoding of that span.  The score
    for subtree is orth(next, current) + orth(previous, current),
    where previous / next are the spans on either side of subtree
    in tree.  A side with no neighbor contributes 0.
    """
    # the neighbor lookups need parent / child_index links,
    # so refresh them over the whole tree first
    tree.mark_parents()

    def encoding_of(span):
        # None stands for "there is no span on this side"
        if span is None:
            return None
        return tree_hx[str(span)]

    left_hx = encoding_of(subtree.find_previous_span())
    right_hx = encoding_of(subtree.find_next_span())
    own_hx = tree_hx[str(subtree)]
    return orth(right_hx, own_hx) + orth(left_hx, own_hx)

def split_span_score(tree, left_subtree, right_subtree, tree_hx):
    """
    Score a pair of adjacent subtrees by adding up their span_independence scores
    """
    left_score = span_independence(tree, left_subtree, tree_hx)
    right_score = span_independence(tree, right_subtree, tree_hx)
    return left_score + right_score

def build_losses(gold_tree, mutated_tree, gold_left, gold_right, mutated_left, mutated_right, gold_hx, mutated_hx):
    """
    Loss comparing the split in the mutated tree against the split in the gold tree

    The result is split_span_score(mutated) + 4 - split_span_score(gold),
    so it shrinks as the gold split scores higher and the mutated split scores lower.
    """
    mutated_score = split_span_score(mutated_tree, mutated_left, mutated_right, mutated_hx)
    gold_score = split_span_score(gold_tree, gold_left, gold_right, gold_hx)
    # the +4 offset is meant to keep the gold term from pushing the loss negative
    # NOTE(review): this relies on split_span_score never exceeding 4 - confirm
    return mutated_score + 4 - gold_score

for mutated_result, (mutated_tree, subtrees), gold_result, gold_tree in zip(mutated_results, mutated_trees, gold_results, gold_trees):
mutated_constituents = mutated_result.constituents
mutated_hx = {}
for con in mutated_constituents:
mutated_hx[str(con.value)] = nn.functional.normalize(con.tree_hx).squeeze(0)

gold_constituents = gold_result.constituents
gold_hx = {}
for con in gold_constituents:
gold_hx[str(con.value)] = nn.functional.normalize(con.tree_hx).squeeze(0)

gold_left, gold_right, mutated_left, mutated_right = subtrees
orthogonal_losses.append(build_losses(gold_tree, mutated_tree, gold_left, gold_right, mutated_left, mutated_right, gold_hx, mutated_hx))

orthogonal_loss = sum(orthogonal_losses) * args['orthogonal_learning_rate']


errors = process_outputs(errors)
Expand Down
4 changes: 2 additions & 2 deletions stanza/tests/constituency/test_parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def test_count_wide_nodes():

def test_count_wide_nodes():
tree = read_one_tree(WIDE_NEIGHBORS_TREE)
new_tree = tree.move_first_wide_neighbor()
new_tree, _ = tree.move_first_wide_neighbor()

expected_move = """
(ROOT
Expand Down Expand Up @@ -502,7 +502,7 @@ def test_count_wide_nodes():
(NP (NNP Windy) (NNP City)))))
"""
smaller_tree = read_one_tree(smaller_tree)
new_tree = smaller_tree.move_first_wide_neighbor()
new_tree, _ = smaller_tree.move_first_wide_neighbor()
expected_smaller_tree = """
(S-HLN
(NP-SBJ (VBN Revitalized) (NNS Classics) (VBP Take))
Expand Down

0 comments on commit 1aaed0c

Please sign in to comment.