Skip to content

Commit

Permalink
Merge pull request #148 from swansonk14/quote-docstrings
Browse files Browse the repository at this point in the history
Fixing issues with comment extraction for the help string
  • Loading branch information
swansonk14 authored Aug 24, 2024
2 parents 5257fe8 + 3980c81 commit c0d4b75
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 15 deletions.
97 changes: 82 additions & 15 deletions src/tap/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from argparse import ArgumentParser, ArgumentTypeError
import ast
from base64 import b64encode, b64decode
import copy
from functools import wraps
Expand All @@ -10,6 +11,7 @@
import re
import subprocess
import sys
import textwrap
import tokenize
from typing import (
Any,
Expand All @@ -20,10 +22,12 @@
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
import warnings

if sys.version_info >= (3, 10):
from types import UnionType
Expand Down Expand Up @@ -184,7 +188,6 @@ def tokenize_source(obj: object) -> Generator:
"""Returns a generator for the tokens of the object's source code."""
source = inspect.getsource(obj)
token_generator = tokenize.generate_tokens(StringIO(source).readline)

return token_generator


Expand All @@ -204,21 +207,65 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
line_to_tokens = {}
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
line_to_tokens.setdefault(start_line, []).append(
{
"token_type": token_type,
"token": token,
"start_line": start_line,
"start_column": start_column,
"end_line": end_line,
"end_column": end_column,
"line": line,
}
)
line_to_tokens.setdefault(start_line, []).append({
'token_type': token_type,
'token': token,
'start_line': start_line,
'start_column': start_column,
'end_line': end_line,
'end_column': end_column,
'line': line
})

return line_to_tokens


def get_subsequent_assign_lines(cls: type) -> Set[int]:
"""For all multiline assign statements, get the line numbers after the first line of the assignment."""
# Get source code of class
source = inspect.getsource(cls)

# Parse source code using ast (with an if statement to avoid indentation errors)
source = f"if True:\n{textwrap.indent(source, ' ')}"
body = ast.parse(source).body[0]

# Set up warning message
parse_warning = (
"Could not parse class source code to extract comments. "
"Comments in the help string may be incorrect."
)

# Check for correct parsing
if not isinstance(body, ast.If):
warnings.warn(parse_warning)
return set()

# Extract if body
if_body = body.body

# Check for a single body
if len(if_body) != 1:
warnings.warn(parse_warning)
return set()

# Extract class body
cls_body = if_body[0]

# Check for a single class definition
if not isinstance(cls_body, ast.ClassDef):
warnings.warn(parse_warning)
return set()

# Get line numbers of assign statements
assign_lines = set()
for node in cls_body.body:
if isinstance(node, (ast.Assign, ast.AnnAssign)):
# Get line number of assign statement excluding the first line (and minus 1 for the if statement)
assign_lines |= set(range(node.lineno, node.end_lineno))

return assign_lines


def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
# Get mapping from line number to tokens
Expand All @@ -227,12 +274,19 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
# Get class variable column number
class_variable_column = get_class_column(cls)

# For all multiline assign statements, get the line numbers after the first line of the assignment
# This is used to avoid identifying comments in multiline assign statements
subsequent_assign_lines = get_subsequent_assign_lines(cls)

# Extract class variables
class_variable = None
variable_to_comment = {}
for tokens in line_to_tokens.values():
for i, token in enumerate(tokens):
for line, tokens in line_to_tokens.items():
# Skip assign lines after the first line of multiline assign statements
if line in subsequent_assign_lines:
continue

for i, token in enumerate(tokens):
# Skip whitespace
if token["token"].strip() == "":
continue
Expand All @@ -244,8 +298,21 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
and token["token"][:1] in {'"', "'"}
):
sep = " " if variable_to_comment[class_variable]["comment"] else ""

# Identify the quote character (single or double)
quote_char = token["token"][:1]
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip(quote_char).strip()

# Identify the number of quote characters at the start of the string
num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char))

# Remove the number of quote characters at the start of the string and the end of the string
token["token"] = token["token"][num_quote_chars:-num_quote_chars]

# Remove the unicode escape sequences (e.g. "\"")
token["token"] = bytes(token["token"], encoding='ascii').decode('unicode-escape')

# Add the token to the comment, stripping whitespace
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip()

# Match class variable
class_variable = None
Expand Down
24 changes: 24 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,30 @@ class TripleQuoteMultiline:
class_variables = {"bar": {"comment": "biz baz"}, "hi": {"comment": "Hello there"}}
self.assertEqual(get_class_variables(TripleQuoteMultiline), class_variables)

def test_comments_with_quotes(self):
class MultiquoteMultiline:
bar: int = 0
'\'\'biz baz\''

hi: str
"\"Hello there\"\""

class_variables = {}
class_variables['bar'] = {'comment': "''biz baz'"}
class_variables['hi'] = {'comment': '"Hello there""'}
self.assertEqual(get_class_variables(MultiquoteMultiline), class_variables)

def test_multiline_argument(self):
class MultilineArgument:
bar: str = (
"This is a multiline argument"
" that should not be included in the docstring"
)
"""biz baz"""

class_variables = {"bar": {"comment": "biz baz"}}
self.assertEqual(get_class_variables(MultilineArgument), class_variables)

def test_single_quote_multiline(self):
class SingleQuoteMultiline:
bar: int = 0
Expand Down

0 comments on commit c0d4b75

Please sign in to comment.