[BE]: Enable F821 and fix bugs (#116579)

Fixes #112371

I tried to fix as many of the bugs as I could, a few I could not figure out what the proper fix for them was though and so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan 2024-01-01 08:40:46 +00:00 committed by PyTorch MergeBot
parent 6c02520466
commit bd10fea79a
86 changed files with 281 additions and 194 deletions

View File

@ -7,7 +7,7 @@ max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,
# fix these lints in the future
E275,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
@ -31,6 +31,8 @@ ignore =
TOR102,
per-file-ignores =
__init__.py: F401
test/**: F821
test/**/__init__.py: F401,F821
torch/utils/cpp_extension.py: B950
torchgen/api/types/__init__.py: F401,F403
torchgen/executorch/api/types/__init__.py: F401,F403

View File

@ -2594,7 +2594,7 @@ init_command = [
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'black==23.3.0',
'black==23.12.1',
'ufmt==2.1.0',
'usort==1.0.6',
]

View File

@ -61,6 +61,7 @@ from torch._dynamo.testing import (
reset_rng_state,
same,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
try:
from torch._dynamo.utils import clone_inputs, graph_break_reasons

View File

@ -28,7 +28,7 @@ SUPPORTED_OPS = {"add_op"}
def parse_op_args(op):
op_list = ops.split(",")
op_list = op.split(",")
def print_results(result):

View File

@ -768,14 +768,14 @@ class SetCriterion(nn.Module):
src_masks = outputs["pred_masks"]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(
target_masks, valid = nested_tensor_from_tensor_list( # noqa: F821
[t["masks"] for t in targets]
).decompose()
target_masks = target_masks.to(src_masks)
src_masks = src_masks[src_idx]
# upsample predictions to the target size
src_masks = interpolate(
src_masks = interpolate( # noqa: F821
src_masks[:, None],
size=target_masks.shape[-2:],
mode="bilinear",
@ -786,8 +786,10 @@ class SetCriterion(nn.Module):
target_masks = target_masks[tgt_idx].flatten(1)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss( # noqa: F821
src_masks, target_masks, num_boxes
), # noqa: F821
"loss_dice": dice_loss(src_masks, target_masks, num_boxes), # noqa: F821
}
return losses

View File

@ -5,6 +5,8 @@ from benchmark_test_generator import _register_test
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
from .benchmark_core import TestConfig
"""Caffe2 performance microbenchmarks.
This module contains Caffe2-specific functionalities for performance

View File

@ -1,3 +1,4 @@
import argparse
import bisect
import itertools
import os

View File

@ -2,7 +2,7 @@ import argparse
import sys
import torch
from utils import Event, gen_sparse_coo, gen_sparse_csr
from utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr
def test_sparse_csr(m, n, k, nnz, test_count):

View File

@ -28,7 +28,8 @@ class Concat2D2InputBench(benchmark.Benchmark):
def reference(self):
return np.concatenate(
(self.numpy(self.input1), self.numpy(self.input2)), axis=concat_dim
(self.numpy(self.input1), self.numpy(self.input2)),
axis=self.concat_dim,
)
def config(self):
@ -97,7 +98,8 @@ class ConcatGraphOptBench(benchmark.Benchmark):
def reference(self):
return np.concatenate(
(self.numpy(self.input1), self.numpy(self.input2)), axis=concat_dim
(self.numpy(self.input1), self.numpy(self.input2)),
axis=self.concat_dim,
)
def config(self):

View File

@ -3,6 +3,14 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union
from . import DimList
_vmap_levels = []
@ -22,11 +30,12 @@ class Dim:
def __del__(self):
if self._vmap_level is not None:
_vmap_active_levels[self._vmap_stack].alive = False
_vmap_active_levels[self._vmap_stack].alive = False # noqa: F821
while (
not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
not _vmap_levels[-1].alive
and current_level() == _vmap_levels[-1].level # noqa: F821
):
_vmap_decrement_nesting()
_vmap_decrement_nesting() # noqa: F821
_vmap_levels.pop()
@property
@ -36,9 +45,11 @@ class Dim:
@size.setter
def size(self, size: int):
from . import DimensionBindError
if self._size is None:
self._size = size
self._vmap_level = _vmap_increment_nesting(size, "same")
self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821
self._vmap_stack = len(_vmap_levels)
_vmap_levels.append(LevelInfo(self._vmap_level))

View File

@ -243,7 +243,7 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
def positional(self, *dims):
from . import Dim, Tensor
from . import Dim, DimensionBindError, Tensor
ptensor, levels = self._tensor, llist(self._levels)
flat_dims = llist()

View File

@ -87,7 +87,7 @@ def upload_file(
waiting_time = datetime.datetime.now() - start_time
if waiting_time > datetime.timedelta(seconds=MAX_UPLOAD_WAIT_IN_SECOND):
raise Exception(
f"Uploading {filename} is taking longer than {MAX_WAIT_IN_SECOND} seconds, terminating..."
f"Uploading {filename} is taking longer than {MAX_UPLOAD_WAIT_IN_SECOND} seconds, terminating..."
)
r = client.get_upload(arn=upload_arn)

View File

@ -40,7 +40,6 @@ ignore = [
"E741",
"EXE001",
"F405",
"F821",
"F841",
# these ignores are from flake8-logging-format; please fix!
"G101",
@ -116,7 +115,15 @@ select = [
]
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"__init__.py" = [
"F401",
]
"test/typing/reveal/**" = [
"F821",
]
"test/torch_np/numpy_tests/**" = [
"F821",
]
"test/jit/**" = [
"PLR0133", # tests require this for JIT
"PYI",

View File

@ -6,6 +6,7 @@ import io
import itertools
import pickle
import sys
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import rpc

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import torch

View File

@ -6,6 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import unittest
import uuid

View File

@ -7,6 +7,7 @@
# LICENSE file in the root directory of this source tree.
import os
import unittest
import sys
import etcd
from torch.distributed.elastic.rendezvous.etcd_rendezvous import (

View File

@ -4,14 +4,15 @@
import torch
import torch.distributed as dist
import contextlib
import copyreg
import os
import sys
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
import copyreg
import os
import contextlib
from torch import multiprocessing
import torch.multiprocessing.reductions as TorchMpReductions
import torch.distributed.rpc as rpc

View File

@ -33,6 +33,7 @@ AOTInductorModelRunner = load_test_module(
"inductor.test_aot_inductor"
).AOTInductorModelRunner
import sys
if not dist.is_available():
print("distributed package not available, skipping tests", file=sys.stderr)

View File

@ -30,6 +30,7 @@ import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from typing import Dict, List
from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
from torch import nn
from torch._C._distributed_c10d import OpType

View File

@ -105,7 +105,7 @@ class TestNCCL(TestCase):
@skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
@skip_but_pass_in_sandcastle_if(
TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,
TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16, # noqa: F821
"Skip bfloat16 test for ROCm < 3.5",
)
@dtypes(*datatypes)

View File

@ -778,7 +778,7 @@ class TimeoutTest(TestCase):
else:
my_store.wait(["foo"], datetime.timedelta(seconds=10))
rank_res[rank] = True
except Error as e:
except Error as e: # noqa: F821
rank_res[rank] = e
time.sleep(1)

View File

@ -631,7 +631,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
@staticmethod
def jvp(ctx, x_t):
if jvp_err:
if jvp_err: # noqa: F821
return x_t
else:
return x_t.mul_(2)
@ -647,7 +647,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
@staticmethod
def jvp(ctx, x_t, y_t):
return x_t + y_t, fn(x_t)
return x_t + y_t, fn(x_t) # noqa: F821
class MyFn3(torch.autograd.Function):
@staticmethod

View File

@ -79,12 +79,12 @@ if TEST_Z3:
unittest.expectedFailure(
# SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
# Ref: https://github.com/sympy/sympy/issues/25146
DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes
DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821
)
unittest.expectedFailure(
# Test is only valid without dynamic shapes
DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes
DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes # noqa: F821
)
if __name__ == "__main__":

View File

@ -818,7 +818,7 @@ class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
)
]
output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
return _NewEmptyTensorOp.apply(x, output_shape) # noqa: F821
class ModuleNameString(torch.nn.Module):

View File

@ -1216,7 +1216,7 @@ class TestAutogradFunction(TestCase):
grad_y = torch.randn_like(x)
def h(x, grad_y):
_, vjp_fn = vjp(f, x)
_, vjp_fn = vjp(f, x) # noqa: F821
gx, = vjp_fn(grad_y)
return gx
@ -1255,7 +1255,7 @@ class TestAutogradFunctionVmapAPI(TestCase):
class NumpyCube(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = to_numpy(input)
input_np = to_numpy(input) # noqa: F821
dinput = torch.tensor(3 * input_np ** 2, device=input.device)
return torch.tensor(input_np ** 3, device=input.device), dinput
@ -1277,7 +1277,7 @@ class TestAutogradFunctionVmapAPI(TestCase):
@staticmethod
def forward(input):
input_np = to_numpy(input)
input_np = to_numpy(input) # noqa: F821
dinput = torch.tensor(3 * input_np ** 2, device=input.device)
return torch.tensor(input_np ** 3, device=input.device), dinput

View File

@ -2,13 +2,14 @@ r'''
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!**
'''
from torch.fx import Graph, GraphModule
from torch.fx import Graph, GraphModule, Node
from torch.fx.graph import map_arg
from torch.fx.proxy import Proxy
import sys
import torch
from torch.nn.utils import fuse_conv_bn_weights
import operator
from typing import Optional
# can be a
# module type, a builtin function, or a string to match target
@ -263,7 +264,7 @@ class Quantizer:
def copy_recursive(node):
def load_or_emit(n):
if n.name in env or e.name in quant_env:
if n.name in env or e.name in quant_env: # noqa: F821
return load_arg(n, quantized=False)
else:
return copy_recursive(n)

View File

@ -10,7 +10,7 @@ if IS_WINDOWS and IS_CI:
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
raise unittest.SkipTest("requires sympy/functorch/filelock") # noqa: F821
import unittest
from typing import List

View File

@ -116,6 +116,8 @@ vec_dtypes = [torch.float, torch.bfloat16, torch.float16]
libfoo = None
f32 = torch.float32
def run_fw_bw_and_get_code(fn):
def run_with_backward():

View File

@ -14,6 +14,7 @@ sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
from typing import List
from torch import Tensor
from torch.jit import Future
class TestAsync(JitTestCase):
def test_async_python(self):

View File

@ -696,12 +696,12 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
}
self.add = torch._C._jit_to_backend(
"backend_with_compiler_demo",
torch.jit.script(ModuleAdd()),
torch.jit.script(ModuleAdd()), # noqa: F821
compile_spec,
)
self.sub = torch._C._jit_to_backend(
"backend_with_compiler_demo",
torch.jit.script(ModuleAdd()),
torch.jit.script(ModuleAdd()), # noqa: F821
compile_spec,
)
@ -715,7 +715,7 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
def setUp(self):
super().setUp()
self.module = CompModule()
self.module = CompModule() # noqa: F821
self.scripted_module = torch.jit.script(self.module)
buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -747,7 +747,7 @@ class AddedAttributesTest(JitBackendTestCase):
input = [(torch.ones(5),)]
pre_bundled = self.lowered_module(*input[0])
# Attach bundled inputs which adds several attributes and functions to the model
self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input)
self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input) # noqa: F821
post_bundled = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0])
# Save and load the lowered module.
self.save_load()

View File

@ -90,7 +90,7 @@ class TestBuiltins(JitTestCase):
def fn(x):
a = x ** 2
del a
return a
return a # noqa: F821
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script
@ -104,7 +104,7 @@ class TestBuiltins(JitTestCase):
@torch.jit.script
def fn(x):
a = x ** 2
del b
del b # noqa: F821
return a
def test_del_multiple_operands(self):

View File

@ -49,9 +49,9 @@ class TestException(TestCase):
def foo(cond):
a = 3
if bool(cond):
raise ArbitraryError(a, "hi")
raise ArbitraryError(a, "hi") # noqa: F821
if 1 == 2:
raise ArbitraryError
raise ArbitraryError # noqa: F821
return a
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):

View File

@ -646,7 +646,7 @@ class TestFreezing(JitTestCase):
self.assertFalse(mf.hasattr('a'))
self.assertTrue(mf.hasattr('b'))
with self.assertRaisesRegex(AttributeError, "TestModule (.*) does not have a field with name '_forward'"):
mf._forward(x)
mf._forward(x) # noqa: F821
def test_freeze_module_with_inplace_mutable(self):
class FreezeMe(torch.jit.ScriptModule):

View File

@ -2629,7 +2629,7 @@ class TestScriptList(JitTestCase):
return self
def __next__(self):
if self.value == limit:
if self.value == limit: # noqa: F821
raise StopIteration()
ret = self.value

View File

@ -280,7 +280,7 @@ class TestMisc(JitTestCase):
self.checkScript(foo, ())
def annotated_list_fail():
return expects_intlist(torch.jit.annotate([], List[Tensor]))
return expects_intlist(torch.jit.annotate([], List[Tensor])) # noqa: F821
with self.assertRaises(RuntimeError):
torch.jit.script(annotated_list_fail)

View File

@ -76,7 +76,7 @@ class TestRecursiveScript(JitTestCase):
def test_failed_function_compilation(self):
def fn(x):
return i_dont_exist
return i_dont_exist # noqa: F821
class M(torch.nn.Module):
def __init__(self, fn):

View File

@ -5,7 +5,7 @@ import unittest
from textwrap import dedent
import torch
from torch import nn
from torch import nn, Tensor
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
from torch.testing._internal.common_utils import make_tensor

View File

@ -1911,7 +1911,7 @@ class TestTracer(JitTestCase):
def test_non_tensor_tracing(self):
def f(x):
return x + param
return x + param # noqa: F821
with self.assertRaisesRegex(RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"):
torch.jit.trace(f, (1,))

View File

@ -382,7 +382,7 @@ class TestTyping(JitTestCase):
@torch.jit.script
def outer_scope_cannot_access_comprehension_variables():
d = {i : chr(i + 65) for i in range(4)}
i = i + 1
i = i + 1 # noqa: F821
def test_for_tuple_assign(self):
def test_simple_assign(x):
@ -596,7 +596,7 @@ class TestTyping(JitTestCase):
def test_namedtuple_error_source_attribution(self):
class _NamedTupleBadMemberType(NamedTuple):
f1: torch.Tensor
f2: "ABadForwardRefType"
f2: "ABadForwardRefType" # noqa: F821
make_global(_NamedTupleBadMemberType) # see [local resolution in python]

View File

@ -3,6 +3,7 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torch import float32, float16
import torch._lazy
import torch._lazy.ts_backend

View File

@ -3,6 +3,7 @@
import torch
import torch.utils.bundled_inputs
import io
from tempfile import TemporaryFileName
from typing import Dict, List
import inspect
from torch.testing import FileCheck

View File

@ -100,4 +100,4 @@ class TestLiteFuseFx(QuantizationLiteTestCase):
if __name__ == "__main__":
run_tests()
run_tests() # noqa: F821

View File

@ -13609,7 +13609,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
)
self.bano1 = torch_geometric_nn.BatchNorm(512)
self.relu = torch.nn.ReLU()
self.dense1 = torch.nn.Seq(Lin(512, 1))
self.dense1 = torch.nn.Seq(Lin(512, 1)) # noqa: F821
self.sigmoid = torch.nn.Sigmoid()
def forward(self, coords0, coords1, edge_from, edge_to):

View File

@ -66,7 +66,7 @@ b=8, k=2
"""
prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model # noqa: F821
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@ -77,9 +77,9 @@ b=4, k=2
"""
prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model # noqa: F821
top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)
top1, top5 = evaluate(quantized_model1, criterion, data_loader_test) # noqa: F821
print(f"Model #1 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
"""

View File

@ -12,7 +12,7 @@ train_batch_size = 30
eval_batch_size = 50
data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss() # noqa: F821
float_model = resnet18(pretrained=True)
float_model.eval()
@ -26,8 +26,8 @@ model_to_quantize.eval()
Prepare model QAT for specified qconfig for torch.nn.Linear
"""
def prepare_qat_linear(qconfig):
qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers
qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]} # noqa: F821
prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers # noqa: F821
training_loop(prepared_model, criterion, data_loader)
prepared_model.eval()
return prepared_model
@ -37,7 +37,7 @@ Prepare model with uniform activation, uniform weight
b=8, k=2
"""
prepared_model = prepare_qat_linear(uniform_qconfig_8bit)
prepared_model = prepare_qat_linear(uniform_qconfig_8bit) # noqa: F821
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@ -47,7 +47,7 @@ Prepare model with uniform activation, uniform weight
b=4, k=2
"""
prepared_model = prepare_qat_linear(uniform_qconfig_4bit)
prepared_model = prepare_qat_linear(uniform_qconfig_4bit) # noqa: F821
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@ -57,7 +57,7 @@ Prepare model with uniform activation, APoT weight
(b=8, k=2)
"""
prepared_model = prepare_qat_linear(apot_weights_qconfig_8bit)
prepared_model = prepare_qat_linear(apot_weights_qconfig_8bit) # noqa: F821
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #2 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@ -67,7 +67,7 @@ Prepare model with uniform activation, APoT weight
(b=4, k=2)
"""
prepared_model = prepare_qat_linear(apot_weights_qconfig_4bit)
prepared_model = prepare_qat_linear(apot_weights_qconfig_4bit) # noqa: F821
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #2 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@ -78,7 +78,7 @@ Prepare model with APoT activation and weight
(b=8, k=2)
"""
prepared_model = prepare_qat_linear(apot_qconfig_8bit)
prepared_model = prepare_qat_linear(apot_qconfig_8bit) # noqa: F821
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #3 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@ -88,7 +88,7 @@ Prepare model with APoT activation and weight
(b=4, k=2)
"""
prepared_model = prepare_qat_linear(apot_qconfig_4bit)
prepared_model = prepare_qat_linear(apot_qconfig_4bit) # noqa: F821
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #3 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

View File

@ -1,4 +1,5 @@
# Owner(s): ["oncall: quantization"]
from typing import Set
import torch
import torch.nn as nn

View File

@ -2560,7 +2560,7 @@ class TestQuantizeFx(QuantizationTestCase):
}
with self.assertRaises(ValueError) as context:
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) # noqa: F821
self.assertTrue(
'Expected qconfig_dict to have the following keys:' in str(context.exception)
)
@ -9315,8 +9315,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
if mode == 'ddp':
mp.spawn(run_ddp,
args=(world_size, prepared),
nprocs=world_size,
args=(world_size, prepared), # noqa: F821
nprocs=world_size, # noqa: F821
join=True)
elif mode == 'qat':
assert prepared.training, 'prepared must be in training mode for qat'
@ -9361,8 +9361,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
# calibration
if mode == 'ddp':
mp.spawn(run_ddp,
args=(world_size, qeager),
nprocs=world_size,
args=(world_size, qeager), # noqa: F821
nprocs=world_size, # noqa: F821
join=True)
elif mode == 'qat':
assert qeager.training, 'qeager should be in training mode for qat'
@ -9526,8 +9526,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
def test_resnet18_ddp(self):
from torchvision import models
from torchvision.models import quantization as quantized_models
eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float()
model = models.__dict__[name](pretrained=False).eval().float()
eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float() # noqa: F821
model = models.__dict__[name](pretrained=False).eval().float() # noqa: F821
self._test_model_impl(
'ddp', 'resnet18', model, eager_quantizable_model)

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import Any, Dict
import torch
import torch._export as export
@ -249,7 +250,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
eps=2**-12
),
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821
MinMaxObserver
)
@ -266,7 +267,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
),
)
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821
PlaceholderObserver
)
bias_quantization_spec = QuantizationSpec(

View File

@ -436,7 +436,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
) -> Tuple[Tensor, Tensor]:
assert (
len(obs_or_fqs) == 2
), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fq)}"
), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
act_obs_or_fq = obs_or_fqs[0]
weight_obs_or_fq = obs_or_fqs[1]
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@ -539,7 +539,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
) -> Tuple[Tensor, Tensor]:
assert (
len(obs_or_fqs) == 1
), f"Expecting one weight obs/fq, got: {len(obs_or_fq)}"
), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
weight_obs_or_fq = obs_or_fqs[0]
(
weight_scale,

View File

@ -2565,7 +2565,7 @@ exit(2)
if not TEST_CUDAMALLOCASYNC:
# These stat checks are specific to the native allocator.
if share_mem != "Don't share":
self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"], # noqa: F821
kSmallBuffer)
else:
reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]

View File

@ -696,7 +696,7 @@ for test_param in supported_tests:
if hasattr(TestExpandedWeightModule, test_name_multi_input):
raise RuntimeError('Found two tests with the same name: ' + test_name)
if decorator is not None:
fn = decorator(fn)
fn = decorator(fn) # noqa: F821
if test.test_cpu:
setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self, 'cpu'))
setattr(TestExpandedWeightModule, test_name_multi_input,

View File

@ -71,7 +71,7 @@ class ForeachFuncWrapper:
class InplaceForeachVersionBumpCheck:
def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None:
def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None: # noqa: F821
self._testcase = testcase
self._tensorlist = tensorlist
self._orig_version_counts = [t._version for t in tensorlist]

View File

@ -50,7 +50,7 @@ from fx.test_source_matcher_utils import TestSourceMatcher # noqa: F401
from fx.test_gradual_type import AnnotationsTest # noqa: F401
from fx.test_gradual_type import TypeCheckerTest # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_MACOS,
@ -1875,13 +1875,13 @@ class TestFX(JitTestCase):
def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)
return super().call_function(n) # noqa: F821
def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)
return super().call_method(n) # noqa: F821
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
@ -1990,13 +1990,13 @@ class TestFX(JitTestCase):
def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)
return super().call_function(n) # noqa: F821
def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)
return super().call_method(n) # noqa: F821
transformed = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)

View File

@ -8163,7 +8163,7 @@ dedent """
@torch.jit.script
def foo(a):
return pyfunc2(a) + pyfunc(a)
return pyfunc2(a) + pyfunc(a) # noqa: F821
inputs = self._make_scalar_vars([1], torch.float)
outputs = self._make_scalar_vars([6], torch.float)

View File

@ -796,7 +796,7 @@ class TestFuser(JitTestCase):
FileCheck.check("FusionGroup").run(str(graph))
except RuntimeError as e:
if 'Failed to compile' in e.args[0]:
warnings.warn('CPU fuser test has failed! This is not a hard failure, '
warnings.warn('CPU fuser test has failed! This is not a hard failure, ' # noqa: F821
'because the kernels sometimes trigger bugs in compilers '
'(most notably GCC 7.2).')
raise unittest.SkipTest('Failed to compile') from e

View File

@ -1289,7 +1289,7 @@ class TestTEFuser(JitTestCase):
self.assertLastGraphAllFused()
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)])
" ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)]) # noqa: F821
) from e
def test_isnan(self):
@ -2706,7 +2706,7 @@ def f({', '.join(param_names)}):
return
try:
with warnings.catch_warnings():
warnings.simplefilter('ignore', TracerWarning)
warnings.simplefilter('ignore', TracerWarning) # noqa: F821
self.te_compile(device, dtype, op)
except Exception as e:
pass

View File

@ -1389,11 +1389,11 @@ class TestAvgPool(TestCaseMPS):
return joined_x.view(1, joined_x.numel())
def _avg_pool2d(self, x, kernel_size):
size = reduce(operator.mul, kernel_size)
size = reduce(operator.mul, kernel_size) # noqa: F821
return self._sum_pool2d(x, kernel_size) / size
def _avg_pool3d(self, x, kernel_size):
size = reduce(operator.mul, kernel_size)
size = reduce(operator.mul, kernel_size) # noqa: F821
return self._sum_pool3d(x, kernel_size) / size
def test_avg_pool2d_with_zero_divisor(self):
@ -6520,7 +6520,7 @@ class TestMPS(TestCaseMPS):
devices += ['mps']
def _gelu_ref(X):
return X * stats.norm.cdf(X)
return X * stats.norm.cdf(X) # noqa: F821
for d in devices:
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]

View File

@ -33,7 +33,7 @@ class TestSortAndSelect(TestCase):
# see above
return ((b != b) | (a <= b)).all().item()
else:
error(f'unknown order "{order}", must be "ascending" or "descending"')
error(f'unknown order "{order}", must be "ascending" or "descending"') # noqa: F821
are_ordered = True
for k in range(1, SIZE):

View File

@ -8924,7 +8924,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
# FIXME: All of the following should be marked as expected failures
# so that it is easier to tell when missing has been added.
# FIXME: fix all the skipped ones below!
test_namespace(torch.randn(1),
test_namespace(torch.randn(1), # noqa: F821
'as_strided_',
re.compile('^clamp_(min|max)_?$'),
'is_distributed',
@ -8946,8 +8946,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
'_autocast_to_fp32',
)
test_namespace(torch.nn)
test_namespace(torch.nn.functional, 'assert_int_or_pair')
test_namespace(torch.nn) # noqa: F821
test_namespace(torch.nn.functional, 'assert_int_or_pair') # noqa: F821
# TODO: add torch.* tests when we have proper namespacing on ATen functions
# test_namespace(torch)

View File

@ -229,7 +229,7 @@ class TestTransformers(NNTestCase):
test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
self.assertFalse(e, "Failed to catch unsupported uint8 type exception") # noqa: F821
test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
encoder.eval()
@ -240,7 +240,7 @@ class TestTransformers(NNTestCase):
test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported Long type exception")
self.assertFalse(e, "Failed to catch unsupported Long type exception") # noqa: F821
test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
@ -917,7 +917,7 @@ class TestTransformers(NNTestCase):
)
if torch_encoder is not None:
self.decoder = torch_to_fairseq(torch_encoder, self.decoder)
self.decoder = torch_to_fairseq(torch_encoder, self.decoder) # noqa: F821
self.decoder = self.decoder.eval().cuda().half()
def forward(

View File

@ -17,8 +17,9 @@ import warnings
import weakref
from contextlib import contextmanager
from decimal import Decimal
from tempfile import mkstemp
from unittest import expectedFailure as xfail, skipIf as skipif
from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
import numpy
import pytest

View File

@ -258,7 +258,7 @@ class ProcessGroup:
def backend(self) -> str: ...
@property
def _timeout(self) -> timedelta: ...
@timeout.setter
@_timeout.setter
def _timeout(self, val: timedelta) -> None: ...
class BackendType(Enum):
@ -507,11 +507,11 @@ class ProcessGroupNCCL(ProcessGroup):
def backend(self) -> str: ...
@property
def _timeout(self) -> timedelta: ...
@timeout.setter
@_timeout.setter
def _timeout(self, val: timedelta) -> None: ...
@property
def _is_high_priority_stream(self) -> bool: ...
@is_high_priority_stream.setter
@_is_high_priority_stream.setter
def _is_high_priority_stream(self, val: bool) -> None: ...
def __init__(

View File

@ -1467,7 +1467,7 @@ if TYPE_CHECKING:
from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
# Fixup segment_reduce visibility
_segment_reduce = segment_reduce
del segment_reduce
del segment_reduce # noqa: F821
# Ops not to be exposed in `torch` namespace,
# mostly helper ops.

View File

@ -1577,6 +1577,9 @@ class BuiltinVariable(VariableTracker):
# not broadcastable, can't be compared
_unimplemented()
tensor_cls = left if isinstance(left, TensorVariable) else right
proxy = tx.output.create_proxy(
"call_function", op, (left.as_proxy(), right.as_proxy()), {}
)
return wrap_fx_proxy_cls(
type(tensor_cls), # handle Ndarrays and Tensors
tx,

View File

@ -2,7 +2,7 @@ import functools
import inspect
import itertools
import types
from typing import Dict, List
from typing import Dict, List, Optional, TYPE_CHECKING, Union
import torch
@ -13,6 +13,9 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, make_cell
from .base import typestr, VariableTracker
if TYPE_CHECKING:
from torch._guards import Source
def wrap_bound_arg(tx, val, source=None):
# Source propagation is best effort since not every object we encounter has a source to begin with.

View File

@ -3,7 +3,7 @@ MAX_CYCLE = 3000
import itertools
import operator
from typing import List, Optional
from typing import Dict, List, Optional
from .. import polyfill, variables
from ..exc import unimplemented

View File

@ -1149,8 +1149,8 @@ def aot_export_joint_simple(
if config.debug_assert:
# Smoke test that after partitioning, we can run the forward without any calling convention changes.
fw_module, bw_module = aot_config.default_partition(
fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)
fw_module, bw_module = aot_config.default_partition( # noqa: F821
fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821
)
# Attempt to run the fw_module with the original user inputs
fake_mode = detect_fake_mode(args)

View File

@ -14,7 +14,7 @@ import itertools
import sympy
from collections import defaultdict
from torch.fx.passes import graph_drawer
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Set, Tuple, Union
from .compile_utils import fx_graph_cse, get_aten_target
from . import config
import functools
@ -325,7 +325,8 @@ def _size_of(node: fx.Node) -> int:
# Only needed since we don't always trace with fake tensors.
if 'tensor_meta' in node.meta:
metadata = node.meta['tensor_meta']
numel = _prod(map(to_size_hint, metadata.shape))
# TODO: What is to_size_hint suppose to be?
numel = _prod(map(to_size_hint, metadata.shape)) # noqa: F821
dtype = metadata.dtype
else:
return 0

View File

@ -17,6 +17,7 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@ -40,6 +41,9 @@ from ..utils import (
)
from ..virtualized import ops, OpsValue, V
if TYPE_CHECKING:
from ..ir import TensorBox
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@ -1288,7 +1292,7 @@ class ChoiceCaller:
def hash_key(self) -> str:
raise NotImplementedError()
def output_node(self) -> "TensorBox": # type: ignore[name-defined]
def output_node(self) -> "TensorBox":
raise NotImplementedError()

View File

@ -180,7 +180,7 @@ def gen_ops() -> List[Any]:
def dtype_match(
torch_dtype: Optional[torch.dtype],
cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined]
cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821
) -> bool:
# Import cutlass python scripts.
assert try_import_cutlass()

View File

@ -264,7 +264,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
return res
@staticmethod
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
@ -277,8 +277,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
@staticmethod
def flip_cutlass_layout(
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
@ -312,7 +312,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
return result
@staticmethod
def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined]
def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined] # noqa: F821
"""
returns True if the op is capable of flexible epilogue fusions
using epilogue visitor trees.
@ -345,7 +345,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
def define_gemm_instance(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined]
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
output_buffer_name: str,
epilogue_nodes: Optional[List[IRNode]] = None,
) -> Tuple[str, str]:
@ -408,8 +408,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
@staticmethod
def swap_XW(
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined]
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined]
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
# Swap X and W in GemmOperation.
new_op = copy.deepcopy(op)
new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout)
@ -421,8 +421,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
def filter_op(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined]
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined]
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
@ -508,7 +508,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
return None
return op
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] # noqa: F821
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
@ -619,7 +619,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
def render( # type: ignore[override]
self,
kernel: CUDATemplateKernel,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined]
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
template_buffer_node: Optional[CUDATemplateBuffer] = None,
epilogue_nodes: Optional[List[IRNode]] = None,
**kwargs,

View File

@ -3312,7 +3312,7 @@ class CUDATemplateBuffer(TemplateBuffer):
inputs,
make_kernel_render,
workspace_size: int,
template: "CUDATemplate", # type: ignore[name-defined]
template: "CUDATemplate", # type: ignore[name-defined] # noqa: F821
):
super().__init__(layout, inputs, make_kernel_render)
# Global memory (in bytes) needed for this template.

View File

@ -375,7 +375,7 @@ def create_submodule_from_subgraph(
# TODO(future PR): this is ignoring kwargs, will need to support kwargs
# for any fusion pattern which has them for a node that is not the
# first node.
cur_args_copy = [cur_node_copy] # type: ignore[has-type]
cur_args_copy = [cur_node_copy] # type: ignore[has-type] # noqa: F821
if len(cur_node_orig.args) > 1:
for arg in cur_node_orig.args[1:]:

View File

@ -1,7 +1,7 @@
import dataclasses
import itertools
import operator
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING
import torch
from torch.fx import Graph, GraphModule, Node
@ -26,6 +26,8 @@ from .utils import (
get_aten_graph_module,
)
if TYPE_CHECKING:
from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
__all__ = [] # type: ignore[var-annotated]
@ -259,7 +261,7 @@ def _get_folded_quantized_qat_conv_bn_pattern(
return _folded_quantized_qat_conv_bn_pattern
def _has_conv_bias_filter(
match: "InternalMatch", # type: ignore[name-defined]
match: "InternalMatch",
original_graph: Graph,
pattern_graph: Graph,
) -> bool:
@ -273,7 +275,7 @@ def _has_conv_bias_filter(
raise ValueError("Could not find conv node in matched conv + bn pattern")
def _no_conv_bias_filter(
match: "InternalMatch", # type: ignore[name-defined]
match: "InternalMatch",
original_graph: Graph,
pattern_graph: Graph,
) -> bool:

View File

@ -20,6 +20,7 @@ from typing import (
Set,
Tuple,
Type,
TYPE_CHECKING,
)
import torch
@ -44,6 +45,9 @@ from .api import (
StateDictType,
)
if TYPE_CHECKING:
from ._flat_param import FlatParamHandle
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
FSDP_FLATTENED = "_fsdp_flattened"

View File

@ -5,12 +5,15 @@ import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch._subclasses.fake_tensor import FakeTensor
from .exported_program import ExportedProgram
if TYPE_CHECKING:
from ..fx.experimental.symbolic_shapes import StrictMinMaxConstraint
__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]
@ -118,7 +121,7 @@ class Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
"""
# NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
constraint_range: "StrictMinMaxConstraint" # type: ignore[name-defined]
constraint_range: "StrictMinMaxConstraint"
# Represent that `constraint_range` is shared with another _ConstraintTarget, which
# typically arises because of a specified equality with another dynamic dimension.
shared: Optional[_ConstraintTarget] = None

View File

@ -15,7 +15,21 @@ from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from typing import Any, cast, Callable, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable
from typing import (
Any,
cast,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
TYPE_CHECKING
)
import torch
import torch.fx
@ -44,6 +58,9 @@ from torch._utils_internal import signpost_event
from torch._logging import LazyString
if TYPE_CHECKING:
from torch._dynamo.source import TensorPropertySource
InputList = List
DimList = List
@ -2066,7 +2083,7 @@ class ShapeEnv:
source: Source,
symbolic_context: SymbolicContext
) -> List[sympy.Expr]:
return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context)
return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context) # noqa: F821
def _produce_dyn_sizes_from_int_tuple(self,
tensor_size: Tuple[int],

View File

@ -6,9 +6,12 @@ from ._compatibility import compatibility
import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
@compatibility(is_backward_compatible=True)
@ -202,7 +205,7 @@ def replace_pattern_with_filters(
gm: GraphModule,
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule],
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
"""
@ -222,7 +225,7 @@ def _replace_pattern(
gm: GraphModule,
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule],
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:

View File

@ -2124,7 +2124,7 @@ class Module:
if child is not None:
child_prefix = prefix + name + '.'
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
load(child, child_state_dict, child_prefix)
load(child, child_state_dict, child_prefix) # noqa: F821
# Note that the hook can modify missing_keys and unexpected_keys.
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

View File

@ -5,7 +5,7 @@
from __future__ import annotations
from typing import Any, Callable, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union
import torch._dynamo
import torch.fx
@ -13,6 +13,9 @@ import torch.onnx
from torch.onnx._internal import _beartype, exporter, io_adapter
from torch.onnx._internal.diagnostics import infra
if TYPE_CHECKING:
from torch.export.exported_program import ExportedProgram
class TorchExport(exporter.FXGraphExtractor):
"""Generates a FX GraphModule using torch.export API
@ -31,7 +34,7 @@ class TorchExport(exporter.FXGraphExtractor):
def generate_fx(
self,
options: exporter.ResolvedExportOptions,
model: "ExportedProgram", # type: ignore[name-defined]
model: "ExportedProgram", # type: ignore[override]
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
) -> torch.fx.GraphModule:

View File

@ -23,17 +23,17 @@ from .lbfgs import LBFGS
from . import lr_scheduler
from . import swa_utils
del adadelta
del adagrad
del adam
del adamw
del sparse_adam
del adamax
del asgd
del sgd
del radam
del rprop
del rmsprop
del optimizer
del nadam
del lbfgs
del adadelta # noqa: F821
del adagrad # noqa: F821
del adam # noqa: F821
del adamw # noqa: F821
del sparse_adam # noqa: F821
del adamax # noqa: F821
del asgd # noqa: F821
del sgd # noqa: F821
del radam # noqa: F821
del rprop # noqa: F821
del rmsprop # noqa: F821
del optimizer # noqa: F821
del nadam # noqa: F821
del lbfgs # noqa: F821

View File

@ -206,7 +206,7 @@ def run_ddp(rank, world_size, prepared):
prepared.to(rank)
model_with_ddp = prepared
optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)
train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1) # noqa: F821
ddp_cleanup()

View File

@ -199,7 +199,7 @@ TestEnvironment.def_flag(
include_in_repro=False)
# NB: enabled by default unless in an fbcode context.
TestEnvironment.def_flag("PRINT_REPRO_ON_FAILURE", env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
default=(not IS_FBCODE), include_in_repro=False)
default=(not IS_FBCODE), include_in_repro=False) # noqa: F821
DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
DEFAULT_SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
@ -760,7 +760,7 @@ parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
default=_get_test_report_path() if IS_CI else None) # noqa: F821
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
@ -1289,7 +1289,7 @@ TestEnvironment.def_flag("TEST_SKIP_FAST", env_var="PYTORCH_TEST_SKIP_FAST")
TestEnvironment.def_flag("TEST_WITH_CROSSREF", env_var="PYTORCH_TEST_WITH_CROSSREF")
TestEnvironment.def_flag("TEST_SKIP_CUDAGRAPH", env_var="PYTORCH_TEST_SKIP_CUDAGRAPH")
TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and ( # noqa: F821
(torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 11) or
(torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
)
@ -1303,7 +1303,7 @@ if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
def skipIfCrossRef(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if TEST_WITH_CROSSREF:
if TEST_WITH_CROSSREF: # noqa: F821
raise unittest.SkipTest("test doesn't currently with crossref")
else:
fn(*args, **kwargs)
@ -1320,25 +1320,25 @@ TestEnvironment.def_flag("TEST_WITH_TORCHINDUCTOR", env_var="PYTORCH_TEST_WITH_I
# AOT_EAGER not tested in ci, useful for debugging
TestEnvironment.def_flag("TEST_WITH_AOT_EAGER", env_var="PYTORCH_TEST_WITH_AOT_EAGER")
TestEnvironment.def_flag("TEST_WITH_TORCHDYNAMO", env_var="PYTORCH_TEST_WITH_DYNAMO",
implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER)
implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER) # noqa: F821
if TEST_WITH_TORCHDYNAMO:
if TEST_WITH_TORCHDYNAMO: # noqa: F821
import torch._dynamo
# Do not spend time on helper functions that are called with different inputs
torch._dynamo.config.accumulated_cache_size_limit = 8
# Do not log compilation metrics from unit tests
torch._dynamo.config.log_compilation_metrics = False
if TEST_WITH_TORCHINDUCTOR:
if TEST_WITH_TORCHINDUCTOR: # noqa: F821
import torch._inductor.config
torch._inductor.config.fallback_random = True
def xpassIfTorchDynamo(func):
return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)
return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func) # noqa: F821
def xfailIfTorchDynamo(func):
return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func
return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func # noqa: F821
def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
@ -1346,14 +1346,14 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
if not isinstance(fn, type):
@wraps(fn)
def wrapper(*args, **kwargs):
if TEST_WITH_TORCHDYNAMO:
if TEST_WITH_TORCHDYNAMO: # noqa: F821
raise unittest.SkipTest(msg)
else:
fn(*args, **kwargs)
return wrapper
assert(isinstance(fn, type))
if TEST_WITH_TORCHDYNAMO:
if TEST_WITH_TORCHDYNAMO: # noqa: F821
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
@ -1363,7 +1363,7 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
return decorator
def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
condition=TEST_WITH_TORCHINDUCTOR):
condition=TEST_WITH_TORCHINDUCTOR): # noqa: F821
def decorator(fn):
if not isinstance(fn, type):
@wraps(fn)
@ -1427,7 +1427,7 @@ def markDynamoStrictTest(cls_or_func=None, nopython=False):
def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR) # noqa: F821
def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
def decorator(fn):
@ -1535,7 +1535,7 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"
@wraps(fn)
def wrapper(*args, **kwargs):
if TEST_WITH_ROCM:
if TEST_WITH_ROCM: # noqa: F821
raise unittest.SkipTest(reason)
else:
return fn(*args, **kwargs)
@ -1547,7 +1547,7 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"
def runOnRocm(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if TEST_WITH_ROCM:
if TEST_WITH_ROCM: # noqa: F821
fn(*args, **kwargs)
else:
raise unittest.SkipTest("test currently only works on the ROCm stack")
@ -1567,7 +1567,7 @@ def skipIfRocmVersionLessThan(version=None):
def dec_fn(fn):
@wraps(fn)
def wrap_fn(self, *args, **kwargs):
if TEST_WITH_ROCM:
if TEST_WITH_ROCM: # noqa: F821
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
@ -1811,7 +1811,7 @@ def skipIfTBB(message="This test makes TBB sad"):
def slowTest(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if not TEST_WITH_SLOW:
if not TEST_WITH_SLOW: # noqa: F821
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
else:
fn(*args, **kwargs)
@ -2177,7 +2177,7 @@ try:
verbosity=hypothesis.Verbosity.verbose))
hypothesis.settings.load_profile(
"pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
"pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev') # noqa: F821
)
except ImportError:
print('Fail to import hypothesis in common_utils, tests are not derandomized')
@ -2218,10 +2218,10 @@ def check_if_enable(test: unittest.TestCase):
if any(matches_test(x) for x in slow_tests_dict.keys()):
getattr(test, test._testMethodName).__dict__['slow_test'] = True
if not TEST_WITH_SLOW:
if not TEST_WITH_SLOW: # noqa: F821
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
if not IS_SANDCASTLE:
if not IS_SANDCASTLE: # noqa: F821
should_skip = False
skip_msg = ""
@ -2233,11 +2233,11 @@ def check_if_enable(test: unittest.TestCase):
"win": IS_WINDOWS,
"windows": IS_WINDOWS,
"linux": IS_LINUX,
"rocm": TEST_WITH_ROCM,
"asan": TEST_WITH_ASAN,
"dynamo": TEST_WITH_TORCHDYNAMO,
"inductor": TEST_WITH_TORCHINDUCTOR,
"slow": TEST_WITH_SLOW,
"rocm": TEST_WITH_ROCM, # noqa: F821
"asan": TEST_WITH_ASAN, # noqa: F821
"dynamo": TEST_WITH_TORCHDYNAMO, # noqa: F821
"inductor": TEST_WITH_TORCHINDUCTOR, # noqa: F821
"slow": TEST_WITH_SLOW, # noqa: F821
}
invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))
@ -2270,7 +2270,7 @@ def check_if_enable(test: unittest.TestCase):
" disabled tests are run"
raise unittest.SkipTest(skip_msg)
if TEST_SKIP_FAST:
if TEST_SKIP_FAST: # noqa: F821
if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
@ -2565,7 +2565,7 @@ class TestCase(expecttest.TestCase):
test_method = getattr(self, method_name, None)
if test_method is not None:
# Wraps the tested method if we should do CUDA memory check.
if TEST_CUDA_MEM_LEAK_CHECK:
if TEST_CUDA_MEM_LEAK_CHECK: # noqa: F821
self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
# FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
if self._do_cuda_memory_leak_check and not IS_WINDOWS:
@ -2579,7 +2579,7 @@ class TestCase(expecttest.TestCase):
if self._ignore_not_implemented_error:
self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))
if PRINT_REPRO_ON_FAILURE:
if PRINT_REPRO_ON_FAILURE: # noqa: F821
env_var_prefix = TestEnvironment.repro_env_var_prefix()
try:
def _get_rel_test_path(abs_test_path):
@ -2684,7 +2684,7 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
test_cls = super_run.__self__
# Are we compiling?
compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR # noqa: F821
# Is the class strict and compiling?
strict_default = False
if compiled:
@ -2716,11 +2716,11 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
else:
supress_errors = torch._dynamo.config.suppress_errors
with unittest.mock.patch("torch._dynamo.config.suppress_errors", supress_errors):
if TEST_WITH_TORCHINDUCTOR:
if TEST_WITH_TORCHINDUCTOR: # noqa: F821
super_run = torch._dynamo.optimize("inductor")(super_run)
elif TEST_WITH_AOT_EAGER:
elif TEST_WITH_AOT_EAGER: # noqa: F821
super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
elif TEST_WITH_TORCHDYNAMO:
elif TEST_WITH_TORCHDYNAMO: # noqa: F821
# TorchDynamo optimize annotation
super_run = torch._dynamo.optimize("eager", nopython=nopython)(super_run)
key = f"{self.__class__.__name__}.{self._testMethodName}"
@ -2757,7 +2757,7 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
def run(self, result=None):
with contextlib.ExitStack() as stack:
if TEST_WITH_CROSSREF:
if TEST_WITH_CROSSREF: # noqa: F821
stack.enter_context(CrossRefMode())
self._run_custom(
result=result,
@ -4312,7 +4312,7 @@ def load_tests(loader, tests, pattern):
set_running_script_path()
test_suite = unittest.TestSuite()
for test_group in tests:
if not DISABLE_RUNNING_SCRIPT_CHK:
if not DISABLE_RUNNING_SCRIPT_CHK: # noqa: F821
for test in test_group:
check_test_defined_in_running_script(test)
if test_group._tests:
@ -4337,7 +4337,7 @@ GRADCHECK_NONDET_TOL = 1e-12
TestEnvironment.def_flag("TEST_WITH_SLOW_GRADCHECK", env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK")
skipIfSlowGradcheckEnv = unittest.skipIf(
TEST_WITH_SLOW_GRADCHECK,
TEST_WITH_SLOW_GRADCHECK, # noqa: F821
"Tests that don't use gradcheck don't need to run on slow_gradcheck CI"
)
@ -4353,7 +4353,7 @@ def gradcheck(fn, inputs, **kwargs):
"fast_mode": True,
}
if TEST_WITH_SLOW_GRADCHECK:
if TEST_WITH_SLOW_GRADCHECK: # noqa: F821
default_values["fast_mode"] = False
for key, value in default_values.items():
@ -4373,7 +4373,7 @@ def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
"fast_mode": True,
}
if TEST_WITH_SLOW_GRADCHECK:
if TEST_WITH_SLOW_GRADCHECK: # noqa: F821
default_values["fast_mode"] = False
for key, value in default_values.items():
@ -4468,7 +4468,7 @@ def skip_but_pass_in_sandcastle(reason):
skipping continuously.
"""
def decorator(func):
if not IS_SANDCASTLE:
if not IS_SANDCASTLE: # noqa: F821
func.__unittest_skip__ = True
func.__unittest_skip_why__ = reason
return func
@ -4573,7 +4573,7 @@ def skip_but_pass_in_sandcastle_if(condition, reason):
"""
def decorator(func):
if condition:
if IS_SANDCASTLE:
if IS_SANDCASTLE: # noqa: F821
@wraps(func)
def wrapper(*args, **kwargs):
print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
@ -4738,7 +4738,7 @@ class TestGradients(TestCase):
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
small_inputs_only=TEST_WITH_SLOW_GRADCHECK)
small_inputs_only=TEST_WITH_SLOW_GRADCHECK) # noqa: F821
for sample in samples:
if sample.broadcasts_input and is_inplace(variant):

View File

@ -7,7 +7,7 @@ import torch.jit
import torch.jit._logging
import torch.jit.frontend
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import is_iterable_of_tensors
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
import collections
from copy import deepcopy

View File

@ -3,7 +3,7 @@ import random
import unittest
from functools import partial
from itertools import chain, product
from typing import Iterable, List
from typing import Iterable, List, Tuple
import numpy as np
from numpy import inf