Incorporate feedback

This commit is contained in:
Frank Harkins 2025-02-22 08:42:57 +00:00
parent 39fdc0a45d
commit d7b618bb02
2 changed files with 18 additions and 10 deletions

View File

@ -75,15 +75,22 @@ async def execute_notebook(job: NotebookJob) -> Result:
print(f"▶️ Executing {job.path}")
working_directory = tempfile.TemporaryDirectory()
possible_exceptions = (
execution_exceptions = (
nbclient.exceptions.CellExecutionError,
nbclient.exceptions.CellTimeoutError,
)
try:
nb = await _execute_notebook(job, working_directory.name)
except possible_exceptions as err:
except execution_exceptions as err:
print(f"❌ Problem in {job.path}:\n{err}")
return Result(False, reason="Exception in notebook")
except SyntaxError as err:
print(
f"❌ Problem in {job.path}:\n"
f"Error parsing code while post-processing (line {err.lineno}):\n"
f" {err.text}"
)
return Result(False, reason="Invalid syntax")
finally:
working_directory.cleanup()

View File

@ -11,12 +11,13 @@
# been altered from the originals.
import ast
import re
import sys
import importlib
import itertools
import re
import sys
from pathlib import Path
from typing_extensions import Iterable
from collections.abc import Iterable
import nbformat
from squeaky import clean_notebook
@ -56,11 +57,11 @@ def get_package_versions(python_code: str, requirements_txt: str) -> str:
# things installed by both 'qiskit' and 'qiskit[visualization]'. For simplicity,
# we include any packages that could be relevant.
module_to_packages = importlib.metadata.packages_distributions()
packages = flatten_to_list([
packages = flatten(
module_to_packages[module]
for module in get_used_modules(python_code)
if module not in sys.stdlib_module_names
])
)
package_versions = "\n".join(
line for line in requirements_txt.split("\n")
if re.split('[\\[~=]', line)[0].strip() in packages
@ -69,10 +70,10 @@ def get_package_versions(python_code: str, requirements_txt: str) -> str:
def get_used_modules(source: str) -> Iterable[str]:
# Remove Jupyter magics
source = "\n".join([
source = "\n".join(
line for line in source.split("\n")
if not line.strip().startswith("%")
])
)
for node in ast.iter_child_nodes(ast.parse(source=source, filename="_.py")):
if isinstance(node, ast.Import):
for module in node.names:
@ -81,5 +82,5 @@ def get_used_modules(source: str) -> Iterable[str]:
if node.module is not None:
yield node.module.split(".")[0]
def flatten_to_list(i: Iterable[Iterable]) -> list:
def flatten(i: Iterable[Iterable]) -> list:
return list(itertools.chain.from_iterable(i))