Skip to content

Commit

Permalink
Use the output_layers ... TBD
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Feb 4, 2025
1 parent 81e6e26 commit f6c262d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
missing_node_errors = []
if epoch <= args['contrastive_final_epoch'] and epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True, keep_constituents=True)
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False, keep_output_layers=True)
for reparsed_result, gold_result in zip(reparsed_results, gold_results):
reparsed_state = reparsed_result.state
reparsed_tree = reparsed_state.constituents.value.value.value
Expand All @@ -628,7 +628,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te

if common_missing_nodes:
synthetic_trees = [x.tree.flip_missing_node_errors(common_missing_nodes) for x in training_batch]
reparsed_results = model.analyze_trees(synthetic_trees, keep_constituents=True, keep_scores=False)
reparsed_results = model.analyze_trees(synthetic_trees, keep_constituents=True, keep_scores=False, keep_output_layers=True)

reparsed_negatives = []
gold_negatives = []
Expand Down

0 comments on commit f6c262d

Please sign in to comment.