diff --git a/scripts/nb-tester/qiskit_docs_notebook_tester/execute.py b/scripts/nb-tester/qiskit_docs_notebook_tester/execute.py index 54ddac4c63..c0d91948eb 100644 --- a/scripts/nb-tester/qiskit_docs_notebook_tester/execute.py +++ b/scripts/nb-tester/qiskit_docs_notebook_tester/execute.py @@ -75,15 +75,22 @@ async def execute_notebook(job: NotebookJob) -> Result: print(f"▶️ Executing {job.path}") working_directory = tempfile.TemporaryDirectory() - possible_exceptions = ( + execution_exceptions = ( nbclient.exceptions.CellExecutionError, nbclient.exceptions.CellTimeoutError, ) try: nb = await _execute_notebook(job, working_directory.name) - except possible_exceptions as err: + except execution_exceptions as err: print(f"❌ Problem in {job.path}:\n{err}") return Result(False, reason="Exception in notebook") + except SyntaxError as err: + print( + f"❌ Problem in {job.path}:\n" + f"Error parsing code while post-processing (line {err.lineno}):\n" + f" {err.text}" + ) + return Result(False, reason="Invalid syntax") finally: working_directory.cleanup() diff --git a/scripts/nb-tester/qiskit_docs_notebook_tester/post_process.py b/scripts/nb-tester/qiskit_docs_notebook_tester/post_process.py index 8bcc0f3464..06d33a5b0c 100644 --- a/scripts/nb-tester/qiskit_docs_notebook_tester/post_process.py +++ b/scripts/nb-tester/qiskit_docs_notebook_tester/post_process.py @@ -11,12 +11,13 @@ # been altered from the originals. import ast -import re -import sys import importlib import itertools +import re +import sys from pathlib import Path -from typing_extensions import Iterable +from collections.abc import Iterable + import nbformat from squeaky import clean_notebook @@ -56,11 +57,11 @@ def get_package_versions(python_code: str, requirements_txt: str) -> str: # things installed by both 'qiskit' and 'qiskit[visualization]'. For simplicity, # we include any packages that could be relevant. module_to_packages = importlib.metadata.packages_distributions() - packages = flatten_to_list([ + packages = flatten( module_to_packages[module] for module in get_used_modules(python_code) if module not in sys.stdlib_module_names - ]) + ) package_versions = "\n".join( line for line in requirements_txt.split("\n") if re.split('[\\[~=]', line)[0].strip() in packages @@ -69,10 +70,10 @@ def get_package_versions(python_code: str, requirements_txt: str) -> str: def get_used_modules(source: str) -> Iterable[str]: # Remove Jupyter magics - source = "\n".join([ + source = "\n".join( line for line in source.split("\n") if not line.strip().startswith("%") - ]) + ) for node in ast.iter_child_nodes(ast.parse(source=source, filename="_.py")): if isinstance(node, ast.Import): for module in node.names: @@ -81,5 +82,5 @@ def get_used_modules(source: str) -> Iterable[str]: if node.module is not None: yield node.module.split(".")[0] -def flatten_to_list(i: Iterable[Iterable]) -> list: +def flatten(i: Iterable[Iterable]) -> list: return list(itertools.chain.from_iterable(i))