[BE]: Enable F821 and fix bugs (#116579)
Fixes #112371. I fixed as many of the bugs as I could; for a few I could not figure out the proper fix, so I left them suppressed with noqa comments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
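For context, F821 is the flake8/ruff rule that flags references to undefined names. A minimal sketch of the kind of bug this PR fixes, and of the noqa suppression it uses where no clean fix was found — the jvp_err snippet is a hypothetical illustration, while parse_op_args mirrors an actual fix in the diff below:

    # F821 ("undefined name"): flake8 flags any name that is never
    # defined or imported in the enclosing scope.

    def parse_op_args(op):
        # Bug: "ops" was undefined here; the fix is to use the
        # parameter "op" (this mirrors one of the fixes in this PR).
        return op.split(",")

    def jvp(ctx, x_t):
        # Where no proper fix was found, the undefined reference is
        # kept and suppressed explicitly:
        if jvp_err:  # noqa: F821
            return x_t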
This commit is contained in:
parent 6c02520466
commit bd10fea79a
.flake8 (4 changes)
@@ -7,7 +7,7 @@ max-line-length = 120
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 ignore =
-    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
+    E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,
     # fix these lints in the future
     E275,
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
@@ -31,6 +31,8 @@ ignore =
     TOR102,
 per-file-ignores =
     __init__.py: F401
+    test/**: F821
+    test/**/__init__.py: F401,F821
     torch/utils/cpp_extension.py: B950
     torchgen/api/types/__init__.py: F401,F403
     torchgen/executorch/api/types/__init__.py: F401,F403

@@ -2594,7 +2594,7 @@ init_command = [
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
     '--no-black-binary',
-    'black==23.3.0',
+    'black==23.12.1',
     'ufmt==2.1.0',
     'usort==1.0.6',
 ]

@@ -61,6 +61,7 @@ from torch._dynamo.testing import (
     reset_rng_state,
     same,
 )
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy

 try:
     from torch._dynamo.utils import clone_inputs, graph_break_reasons

@@ -28,7 +28,7 @@ SUPPORTED_OPS = {"add_op"}


 def parse_op_args(op):
-    op_list = ops.split(",")
+    op_list = op.split(",")


 def print_results(result):

@@ -768,14 +768,14 @@ class SetCriterion(nn.Module):
         src_masks = outputs["pred_masks"]

         # TODO use valid to mask invalid areas due to padding in loss
-        target_masks, valid = nested_tensor_from_tensor_list(
+        target_masks, valid = nested_tensor_from_tensor_list(  # noqa: F821
             [t["masks"] for t in targets]
         ).decompose()
         target_masks = target_masks.to(src_masks)

         src_masks = src_masks[src_idx]
         # upsample predictions to the target size
-        src_masks = interpolate(
+        src_masks = interpolate(  # noqa: F821
             src_masks[:, None],
             size=target_masks.shape[-2:],
             mode="bilinear",
@@ -786,8 +786,10 @@ class SetCriterion(nn.Module):
         target_masks = target_masks[tgt_idx].flatten(1)

         losses = {
-            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
-            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+            "loss_mask": sigmoid_focal_loss(  # noqa: F821
+                src_masks, target_masks, num_boxes
+            ),  # noqa: F821
+            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),  # noqa: F821
         }
         return losses

@@ -5,6 +5,8 @@ from benchmark_test_generator import _register_test
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
+
+from .benchmark_core import TestConfig

 """Caffe2 performance microbenchmarks.

 This module contains Caffe2-specific functionalities for performance

@@ -1,3 +1,4 @@
 import argparse
+import bisect
 import itertools
 import os

@@ -2,7 +2,7 @@ import argparse
 import sys

 import torch
-from utils import Event, gen_sparse_coo, gen_sparse_csr
+from utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr


 def test_sparse_csr(m, n, k, nnz, test_count):

@@ -28,7 +28,8 @@ class Concat2D2InputBench(benchmark.Benchmark):

     def reference(self):
         return np.concatenate(
-            (self.numpy(self.input1), self.numpy(self.input2)), axis=concat_dim
+            (self.numpy(self.input1), self.numpy(self.input2)),
+            axis=self.concat_dim,
         )

     def config(self):
@@ -97,7 +98,8 @@ class ConcatGraphOptBench(benchmark.Benchmark):

     def reference(self):
         return np.concatenate(
-            (self.numpy(self.input1), self.numpy(self.input2)), axis=concat_dim
+            (self.numpy(self.input1), self.numpy(self.input2)),
+            axis=self.concat_dim,
         )

     def config(self):

@@ -3,6 +3,14 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import dis
+import inspect
+
+from dataclasses import dataclass
+from typing import Union
+
+from . import DimList
+
 _vmap_levels = []

@@ -22,11 +30,12 @@ class Dim:

     def __del__(self):
         if self._vmap_level is not None:
-            _vmap_active_levels[self._vmap_stack].alive = False
+            _vmap_active_levels[self._vmap_stack].alive = False  # noqa: F821
             while (
-                not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
+                not _vmap_levels[-1].alive
+                and current_level() == _vmap_levels[-1].level  # noqa: F821
             ):
-                _vmap_decrement_nesting()
+                _vmap_decrement_nesting()  # noqa: F821
                 _vmap_levels.pop()

     @property
@@ -36,9 +45,11 @@ class Dim:

     @size.setter
     def size(self, size: int):
+        from . import DimensionBindError
+
         if self._size is None:
             self._size = size
-            self._vmap_level = _vmap_increment_nesting(size, "same")
+            self._vmap_level = _vmap_increment_nesting(size, "same")  # noqa: F821
             self._vmap_stack = len(_vmap_levels)
             _vmap_levels.append(LevelInfo(self._vmap_level))

@@ -243,7 +243,7 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):


 def positional(self, *dims):
-    from . import Dim, Tensor
+    from . import Dim, DimensionBindError, Tensor

     ptensor, levels = self._tensor, llist(self._levels)
     flat_dims = llist()

@@ -87,7 +87,7 @@ def upload_file(
         waiting_time = datetime.datetime.now() - start_time
         if waiting_time > datetime.timedelta(seconds=MAX_UPLOAD_WAIT_IN_SECOND):
             raise Exception(
-                f"Uploading {filename} is taking longer than {MAX_WAIT_IN_SECOND} seconds, terminating..."
+                f"Uploading {filename} is taking longer than {MAX_UPLOAD_WAIT_IN_SECOND} seconds, terminating..."
             )

         r = client.get_upload(arn=upload_arn)

@@ -40,7 +40,6 @@ ignore = [
     "E741",
     "EXE001",
     "F405",
-    "F821",
     "F841",
     # these ignores are from flake8-logging-format; please fix!
     "G101",
@@ -116,7 +115,15 @@ select = [
 ]

 [tool.ruff.per-file-ignores]
-"__init__.py" = ["F401"]
+"__init__.py" = [
+    "F401",
+]
+"test/typing/reveal/**" = [
+    "F821",
+]
+"test/torch_np/numpy_tests/**" = [
+    "F821",
+]
 "test/jit/**" = [
     "PLR0133", # tests require this for JIT
     "PYI",

@@ -6,6 +6,7 @@ import io
 import itertools
 import pickle
 import sys
+from typing import List
 import torch
 import torch.distributed as dist
 from torch.distributed import rpc

@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: distributed"]

+import os
 import sys

 import torch

@@ -6,6 +6,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+import sys
 import unittest
 import uuid

@@ -7,6 +7,7 @@
 # LICENSE file in the root directory of this source tree.
 import os
 import unittest
+import sys

 import etcd
 from torch.distributed.elastic.rendezvous.etcd_rendezvous import (

@@ -4,14 +4,15 @@
 import torch
 import torch.distributed as dist

+import contextlib
+import copyreg
+import os
+import sys
+
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)

-import copyreg
-import os
-import contextlib
-
 from torch import multiprocessing
 import torch.multiprocessing.reductions as TorchMpReductions
 import torch.distributed.rpc as rpc

@@ -33,6 +33,7 @@ AOTInductorModelRunner = load_test_module(
     "inductor.test_aot_inductor"
 ).AOTInductorModelRunner

+import sys

 if not dist.is_available():
     print("distributed package not available, skipping tests", file=sys.stderr)

@@ -30,6 +30,7 @@ import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
 import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
 import torch.nn.functional as F
 import torch.testing._internal.common_utils as common
+from typing import Dict, List
 from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
 from torch import nn
 from torch._C._distributed_c10d import OpType

@@ -105,7 +105,7 @@ class TestNCCL(TestCase):
     @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
     @skip_but_pass_in_sandcastle_if(
-        TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,
+        TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,  # noqa: F821
         "Skip bfloat16 test for ROCm < 3.5",
     )
     @dtypes(*datatypes)

@@ -778,7 +778,7 @@ class TimeoutTest(TestCase):
                 else:
                     my_store.wait(["foo"], datetime.timedelta(seconds=10))
                     rank_res[rank] = True
-            except Error as e:
+            except Error as e:  # noqa: F821
                 rank_res[rank] = e
             time.sleep(1)

@@ -631,7 +631,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     @staticmethod
     def jvp(ctx, x_t):
-        if jvp_err:
+        if jvp_err:  # noqa: F821
             return x_t
         else:
             return x_t.mul_(2)
@@ -647,7 +647,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     @staticmethod
     def jvp(ctx, x_t, y_t):
-        return x_t + y_t, fn(x_t)
+        return x_t + y_t, fn(x_t)  # noqa: F821

 class MyFn3(torch.autograd.Function):
     @staticmethod

@@ -79,12 +79,12 @@ if TEST_Z3:
     unittest.expectedFailure(
         # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
         # Ref: https://github.com/sympy/sympy/issues/25146
-        DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes
+        DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes  # noqa: F821
     )

     unittest.expectedFailure(
         # Test is only valid without dynamic shapes
-        DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes
+        DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes  # noqa: F821
     )

 if __name__ == "__main__":

@@ -818,7 +818,7 @@ class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
             )
         ]
         output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
-        return _NewEmptyTensorOp.apply(x, output_shape)
+        return _NewEmptyTensorOp.apply(x, output_shape)  # noqa: F821


 class ModuleNameString(torch.nn.Module):

@@ -1216,7 +1216,7 @@ class TestAutogradFunction(TestCase):
         grad_y = torch.randn_like(x)

         def h(x, grad_y):
-            _, vjp_fn = vjp(f, x)
+            _, vjp_fn = vjp(f, x)  # noqa: F821
             gx, = vjp_fn(grad_y)
             return gx
@@ -1255,7 +1255,7 @@ class TestAutogradFunctionVmapAPI(TestCase):
         class NumpyCube(torch.autograd.Function):
             @staticmethod
             def forward(input):
-                input_np = to_numpy(input)
+                input_np = to_numpy(input)  # noqa: F821
                 dinput = torch.tensor(3 * input_np ** 2, device=input.device)
                 return torch.tensor(input_np ** 3, device=input.device), dinput
@@ -1277,7 +1277,7 @@ class TestAutogradFunctionVmapAPI(TestCase):

             @staticmethod
             def forward(input):
-                input_np = to_numpy(input)
+                input_np = to_numpy(input)  # noqa: F821
                 dinput = torch.tensor(3 * input_np ** 2, device=input.device)
                 return torch.tensor(input_np ** 3, device=input.device), dinput

@@ -2,13 +2,14 @@ r'''
 **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
 rely on it for anything!**
 '''
-from torch.fx import Graph, GraphModule
+from torch.fx import Graph, GraphModule, Node
 from torch.fx.graph import map_arg
 from torch.fx.proxy import Proxy
 import sys
 import torch
 from torch.nn.utils import fuse_conv_bn_weights
 import operator
+from typing import Optional

 # can be a
 # module type, a builtin function, or a string to match target
@@ -263,7 +264,7 @@ class Quantizer:

     def copy_recursive(node):
         def load_or_emit(n):
-            if n.name in env or e.name in quant_env:
+            if n.name in env or e.name in quant_env:  # noqa: F821
                 return load_arg(n, quantized=False)
             else:
                 return copy_recursive(n)

@@ -10,7 +10,7 @@ if IS_WINDOWS and IS_CI:
     )
     if __name__ == "__main__":
         sys.exit(0)
-    raise unittest.SkipTest("requires sympy/functorch/filelock")
+    raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

 import unittest
 from typing import List

@@ -116,6 +116,8 @@ vec_dtypes = [torch.float, torch.bfloat16, torch.float16]

 libfoo = None

+f32 = torch.float32
+

 def run_fw_bw_and_get_code(fn):
     def run_with_backward():

@@ -14,6 +14,7 @@ sys.path.append(pytorch_test_dir)
 from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
 from typing import List
 from torch import Tensor
+from torch.jit import Future

 class TestAsync(JitTestCase):
     def test_async_python(self):

@@ -696,12 +696,12 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
         }
         self.add = torch._C._jit_to_backend(
             "backend_with_compiler_demo",
-            torch.jit.script(ModuleAdd()),
+            torch.jit.script(ModuleAdd()),  # noqa: F821
             compile_spec,
         )
         self.sub = torch._C._jit_to_backend(
             "backend_with_compiler_demo",
-            torch.jit.script(ModuleAdd()),
+            torch.jit.script(ModuleAdd()),  # noqa: F821
             compile_spec,
         )
@@ -715,7 +715,7 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
     def setUp(self):
         super().setUp()

-        self.module = CompModule()
+        self.module = CompModule()  # noqa: F821
         self.scripted_module = torch.jit.script(self.module)
         buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
         buffer.seek(0)
@@ -747,7 +747,7 @@ class AddedAttributesTest(JitBackendTestCase):
         input = [(torch.ones(5),)]
         pre_bundled = self.lowered_module(*input[0])
         # Attach bundled inputs which adds several attributes and functions to the model
-        self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input)
+        self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input)  # noqa: F821
         post_bundled = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0])
         # Save and load the lowered module.
         self.save_load()

@@ -90,7 +90,7 @@ class TestBuiltins(JitTestCase):
         def fn(x):
             a = x ** 2
             del a
-            return a
+            return a  # noqa: F821

         with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
             @torch.jit.script
@@ -104,7 +104,7 @@ class TestBuiltins(JitTestCase):
         @torch.jit.script
         def fn(x):
             a = x ** 2
-            del b
+            del b  # noqa: F821
             return a

     def test_del_multiple_operands(self):

@@ -49,9 +49,9 @@ class TestException(TestCase):
         def foo(cond):
             a = 3
             if bool(cond):
-                raise ArbitraryError(a, "hi")
+                raise ArbitraryError(a, "hi")  # noqa: F821
             if 1 == 2:
-                raise ArbitraryError
+                raise ArbitraryError  # noqa: F821
             return a

         with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):

@@ -646,7 +646,7 @@ class TestFreezing(JitTestCase):
         self.assertFalse(mf.hasattr('a'))
         self.assertTrue(mf.hasattr('b'))
         with self.assertRaisesRegex(AttributeError, "TestModule (.*) does not have a field with name '_forward'"):
-            mf._forward(x)
+            mf._forward(x)  # noqa: F821

     def test_freeze_module_with_inplace_mutable(self):
         class FreezeMe(torch.jit.ScriptModule):

@@ -2629,7 +2629,7 @@ class TestScriptList(JitTestCase):
                 return self

             def __next__(self):
-                if self.value == limit:
+                if self.value == limit:  # noqa: F821
                     raise StopIteration()

                 ret = self.value

@@ -280,7 +280,7 @@ class TestMisc(JitTestCase):
         self.checkScript(foo, ())

         def annotated_list_fail():
-            return expects_intlist(torch.jit.annotate([], List[Tensor]))
+            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821

         with self.assertRaises(RuntimeError):
             torch.jit.script(annotated_list_fail)

@@ -76,7 +76,7 @@ class TestRecursiveScript(JitTestCase):

     def test_failed_function_compilation(self):
         def fn(x):
-            return i_dont_exist
+            return i_dont_exist  # noqa: F821

         class M(torch.nn.Module):
             def __init__(self, fn):

@@ -5,7 +5,7 @@ import unittest
 from textwrap import dedent

 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.testing import FileCheck
 from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
 from torch.testing._internal.common_utils import make_tensor

@@ -1911,7 +1911,7 @@ class TestTracer(JitTestCase):

     def test_non_tensor_tracing(self):
         def f(x):
-            return x + param
+            return x + param  # noqa: F821
         with self.assertRaisesRegex(RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"):
             torch.jit.trace(f, (1,))

@@ -382,7 +382,7 @@ class TestTyping(JitTestCase):
         @torch.jit.script
         def outer_scope_cannot_access_comprehension_variables():
             d = {i : chr(i + 65) for i in range(4)}
-            i = i + 1
+            i = i + 1  # noqa: F821

     def test_for_tuple_assign(self):
         def test_simple_assign(x):
@@ -596,7 +596,7 @@ class TestTyping(JitTestCase):
     def test_namedtuple_error_source_attribution(self):
         class _NamedTupleBadMemberType(NamedTuple):
             f1: torch.Tensor
-            f2: "ABadForwardRefType"
+            f2: "ABadForwardRefType"  # noqa: F821

         make_global(_NamedTupleBadMemberType)  # see [local resolution in python]

@@ -3,6 +3,7 @@
 import torch

 from torch.testing._internal.common_utils import TestCase
+from torch import float32, float16
 import torch._lazy
 import torch._lazy.ts_backend

@@ -3,6 +3,7 @@
 import torch
 import torch.utils.bundled_inputs
 import io
+from tempfile import TemporaryFileName
 from typing import Dict, List
 import inspect
 from torch.testing import FileCheck

@@ -100,4 +100,4 @@ class TestLiteFuseFx(QuantizationLiteTestCase):


 if __name__ == "__main__":
-    run_tests()
+    run_tests()  # noqa: F821

@@ -13609,7 +13609,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
                 )
                 self.bano1 = torch_geometric_nn.BatchNorm(512)
                 self.relu = torch.nn.ReLU()
-                self.dense1 = torch.nn.Seq(Lin(512, 1))
+                self.dense1 = torch.nn.Seq(Lin(512, 1))  # noqa: F821
                 self.sigmoid = torch.nn.Sigmoid()

             def forward(self, coords0, coords1, edge_from, edge_to):

@@ -66,7 +66,7 @@ b=8, k=2
 """

 prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
-quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
+quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model # noqa: F821

 top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
 print(f"Model #1 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@@ -77,9 +77,9 @@ b=4, k=2
 """

 prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
-quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
+quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model # noqa: F821

-top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)
+top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)  # noqa: F821
 print(f"Model #1 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

 """

@@ -12,7 +12,7 @@ train_batch_size = 30
 eval_batch_size = 50

 data_loader, data_loader_test = prepare_data_loaders(data_path)
-criterion = nn.CrossEntropyLoss()
+criterion = nn.CrossEntropyLoss()  # noqa: F821
 float_model = resnet18(pretrained=True)
 float_model.eval()
@@ -26,8 +26,8 @@ model_to_quantize.eval()
 Prepare model QAT for specified qconfig for torch.nn.Linear
 """
 def prepare_qat_linear(qconfig):
-    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
-    prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules and insert observers
+    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}  # noqa: F821
+    prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules and insert observers # noqa: F821
     training_loop(prepared_model, criterion, data_loader)
     prepared_model.eval()
     return prepared_model
@@ -37,7 +37,7 @@ Prepare model with uniform activation, uniform weight
 b=8, k=2
 """

-prepared_model = prepare_qat_linear(uniform_qconfig_8bit)
+prepared_model = prepare_qat_linear(uniform_qconfig_8bit)  # noqa: F821

 top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
 print(f"Model #1 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@@ -47,7 +47,7 @@ Prepare model with uniform activation, uniform weight
 b=4, k=2
 """

-prepared_model = prepare_qat_linear(uniform_qconfig_4bit)
+prepared_model = prepare_qat_linear(uniform_qconfig_4bit)  # noqa: F821

 top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
 print(f"Model #1 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@@ -57,7 +57,7 @@ Prepare model with uniform activation, APoT weight
 (b=8, k=2)
 """

-prepared_model = prepare_qat_linear(apot_weights_qconfig_8bit)
+prepared_model = prepare_qat_linear(apot_weights_qconfig_8bit)  # noqa: F821

 top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
 print(f"Model #2 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@@ -67,7 +67,7 @@ Prepare model with uniform activation, APoT weight
 (b=4, k=2)
 """

-prepared_model = prepare_qat_linear(apot_weights_qconfig_4bit)
+prepared_model = prepare_qat_linear(apot_weights_qconfig_4bit)  # noqa: F821

 top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
 print(f"Model #2 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@@ -78,7 +78,7 @@ Prepare model with APoT activation and weight
 (b=8, k=2)
 """

-prepared_model = prepare_qat_linear(apot_qconfig_8bit)
+prepared_model = prepare_qat_linear(apot_qconfig_8bit)  # noqa: F821

 top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
 print(f"Model #3 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
@@ -88,7 +88,7 @@ Prepare model with APoT activation and weight
 (b=4, k=2)
 """

-prepared_model = prepare_qat_linear(apot_qconfig_4bit)
+prepared_model = prepare_qat_linear(apot_qconfig_4bit)  # noqa: F821

 top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
 print(f"Model #3 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

@@ -1,4 +1,5 @@
 # Owner(s): ["oncall: quantization"]
+from typing import Set

 import torch
 import torch.nn as nn

@@ -2560,7 +2560,7 @@ class TestQuantizeFx(QuantizationTestCase):
         }

         with self.assertRaises(ValueError) as context:
-            m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
+            m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))  # noqa: F821
         self.assertTrue(
             'Expected qconfig_dict to have the following keys:' in str(context.exception)
         )
@@ -9315,8 +9315,8 @@ class TestQuantizeFxModels(QuantizationTestCase):

         if mode == 'ddp':
             mp.spawn(run_ddp,
-                     args=(world_size, prepared),
-                     nprocs=world_size,
+                     args=(world_size, prepared),  # noqa: F821
+                     nprocs=world_size,  # noqa: F821
                      join=True)
         elif mode == 'qat':
             assert prepared.training, 'prepared must be in training mode for qat'
@@ -9361,8 +9361,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
         # calibration
         if mode == 'ddp':
             mp.spawn(run_ddp,
-                     args=(world_size, qeager),
-                     nprocs=world_size,
+                     args=(world_size, qeager),  # noqa: F821
+                     nprocs=world_size,  # noqa: F821
                      join=True)
         elif mode == 'qat':
             assert qeager.training, 'qeager should be in training mode for qat'
@@ -9526,8 +9526,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
     def test_resnet18_ddp(self):
         from torchvision import models
         from torchvision.models import quantization as quantized_models
-        eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float()
-        model = models.__dict__[name](pretrained=False).eval().float()
+        eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float()  # noqa: F821
+        model = models.__dict__[name](pretrained=False).eval().float()  # noqa: F821
         self._test_model_impl(
             'ddp', 'resnet18', model, eager_quantizable_model)

@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: quantization"]
 import copy
 import unittest
+from typing import Any, Dict

 import torch
 import torch._export as export
@@ -249,7 +250,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
                 eps=2**-12
             ),
         )
-        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
             MinMaxObserver
         )
@@ -266,7 +267,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
             ),
         )

-        bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
             PlaceholderObserver
         )
         bias_quantization_spec = QuantizationSpec(

@@ -436,7 +436,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         ) -> Tuple[Tensor, Tensor]:
             assert (
                 len(obs_or_fqs) == 2
-            ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fq)}"
+            ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
             act_obs_or_fq = obs_or_fqs[0]
             weight_obs_or_fq = obs_or_fqs[1]
             act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@@ -539,7 +539,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         ) -> Tuple[Tensor, Tensor]:
             assert (
                 len(obs_or_fqs) == 1
-            ), f"Expecting one weight obs/fq, got: {len(obs_or_fq)}"
+            ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
             weight_obs_or_fq = obs_or_fqs[0]
             (
                 weight_scale,

@@ -2565,7 +2565,7 @@ exit(2)
         if not TEST_CUDAMALLOCASYNC:
             # These stat checks are specific to the native allocator.
             if share_mem != "Don't share":
-                self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
+                self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],  # noqa: F821
                                  kSmallBuffer)
             else:
                 reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]

@@ -696,7 +696,7 @@ for test_param in supported_tests:
     if hasattr(TestExpandedWeightModule, test_name_multi_input):
         raise RuntimeError('Found two tests with the same name: ' + test_name)
     if decorator is not None:
-        fn = decorator(fn)
+        fn = decorator(fn)  # noqa: F821
     if test.test_cpu:
         setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self, 'cpu'))
         setattr(TestExpandedWeightModule, test_name_multi_input,

@@ -71,7 +71,7 @@ class ForeachFuncWrapper:

 class InplaceForeachVersionBumpCheck:

-    def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None:
+    def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None:  # noqa: F821
         self._testcase = testcase
         self._tensorlist = tensorlist
         self._orig_version_counts = [t._version for t in tensorlist]

@@ -50,7 +50,7 @@ from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

 from fx.test_gradual_type import AnnotationsTest  # noqa: F401
 from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
-from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
 from torch.testing._internal.common_utils import (
     IS_FBCODE,
     IS_MACOS,
@@ -1875,13 +1875,13 @@ class TestFX(JitTestCase):
             def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                 if target == torch.sigmoid:
                     return torch.neg(*args, **kwargs)
-                return super().call_function(n)
+                return super().call_function(n)  # noqa: F821

             def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                 if target == 'neg':
                     call_self, *args_tail = args
                     return call_self.sigmoid(*args_tail, **kwargs)
-                return super().call_method(n)
+                return super().call_method(n)  # noqa: F821

         input = torch.randn(3, 4)
         result = NegSigmSwapInterpreter(gm).run(input)
@@ -1990,13 +1990,13 @@ class TestFX(JitTestCase):
             def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                 if target == torch.sigmoid:
                     return torch.neg(*args, **kwargs)
-                return super().call_function(n)
+                return super().call_function(n)  # noqa: F821

             def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                 if target == 'neg':
                     call_self, *args_tail = args
                     return call_self.sigmoid(*args_tail, **kwargs)
-                return super().call_method(n)
+                return super().call_method(n)  # noqa: F821

         transformed = NegSigmSwapXformer(gm).transform()
         input = torch.randn(3, 4)

@@ -8163,7 +8163,7 @@ dedent """

         @torch.jit.script
         def foo(a):
-            return pyfunc2(a) + pyfunc(a)
+            return pyfunc2(a) + pyfunc(a)  # noqa: F821

         inputs = self._make_scalar_vars([1], torch.float)
         outputs = self._make_scalar_vars([6], torch.float)

@@ -796,7 +796,7 @@ class TestFuser(JitTestCase):
             FileCheck.check("FusionGroup").run(str(graph))
         except RuntimeError as e:
             if 'Failed to compile' in e.args[0]:
-                warnings.warn('CPU fuser test has failed! This is not a hard failure, '
+                warnings.warn('CPU fuser test has failed! This is not a hard failure, '  # noqa: F821
                               'because the kernels sometimes trigger bugs in compilers '
                               '(most notably GCC 7.2).')
                 raise unittest.SkipTest('Failed to compile') from e

@@ -1289,7 +1289,7 @@ class TestTEFuser(JitTestCase):
                 self.assertLastGraphAllFused()
             except Exception as e:
                 raise RuntimeError(
-                    " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)])
+                    " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)])  # noqa: F821
                 ) from e

     def test_isnan(self):
@@ -2706,7 +2706,7 @@ def f({', '.join(param_names)}):
             return
         try:
             with warnings.catch_warnings():
-                warnings.simplefilter('ignore', TracerWarning)
+                warnings.simplefilter('ignore', TracerWarning)  # noqa: F821
                 self.te_compile(device, dtype, op)
         except Exception as e:
             pass

@@ -1389,11 +1389,11 @@ class TestAvgPool(TestCaseMPS):
         return joined_x.view(1, joined_x.numel())

     def _avg_pool2d(self, x, kernel_size):
-        size = reduce(operator.mul, kernel_size)
+        size = reduce(operator.mul, kernel_size)  # noqa: F821
         return self._sum_pool2d(x, kernel_size) / size

     def _avg_pool3d(self, x, kernel_size):
-        size = reduce(operator.mul, kernel_size)
+        size = reduce(operator.mul, kernel_size)  # noqa: F821
         return self._sum_pool3d(x, kernel_size) / size

     def test_avg_pool2d_with_zero_divisor(self):
@@ -6520,7 +6520,7 @@ class TestMPS(TestCaseMPS):
         devices += ['mps']

         def _gelu_ref(X):
-            return X * stats.norm.cdf(X)
+            return X * stats.norm.cdf(X)  # noqa: F821

         for d in devices:
             X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]

@@ -33,7 +33,7 @@ class TestSortAndSelect(TestCase):
             # see above
             return ((b != b) | (a <= b)).all().item()
         else:
-            error(f'unknown order "{order}", must be "ascending" or "descending"')
+            error(f'unknown order "{order}", must be "ascending" or "descending"')  # noqa: F821

         are_ordered = True
         for k in range(1, SIZE):

@@ -8924,7 +8924,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         # FIXME: All of the following should be marked as expected failures
         # so that it is easier to tell when missing has been added.
         # FIXME: fix all the skipped ones below!
-        test_namespace(torch.randn(1),
+        test_namespace(torch.randn(1),  # noqa: F821
                        'as_strided_',
                        re.compile('^clamp_(min|max)_?$'),
                        'is_distributed',
@@ -8946,8 +8946,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
                        '_autocast_to_fp32',
                        )

-        test_namespace(torch.nn)
-        test_namespace(torch.nn.functional, 'assert_int_or_pair')
+        test_namespace(torch.nn)  # noqa: F821
+        test_namespace(torch.nn.functional, 'assert_int_or_pair')  # noqa: F821
         # TODO: add torch.* tests when we have proper namespacing on ATen functions
         # test_namespace(torch)

@@ -229,7 +229,7 @@ class TestTransformers(NNTestCase):
                 test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
             except AssertionError as e:
                 continue
-            self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
+            self.assertFalse(e, "Failed to catch unsupported uint8 type exception")  # noqa: F821

             test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
             encoder.eval()
@@ -240,7 +240,7 @@ class TestTransformers(NNTestCase):
                 test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
             except AssertionError as e:
                 continue
-            self.assertFalse(e, "Failed to catch unsupported Long type exception")
+            self.assertFalse(e, "Failed to catch unsupported Long type exception")  # noqa: F821

             test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
             l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
@@ -917,7 +917,7 @@ class TestTransformers(NNTestCase):
                 )

                 if torch_encoder is not None:
-                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)
+                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)  # noqa: F821
                 self.decoder = self.decoder.eval().cuda().half()

             def forward(

@@ -17,8 +17,9 @@ import warnings
 import weakref
 from contextlib import contextmanager
+from decimal import Decimal
 from tempfile import mkstemp

-from unittest import expectedFailure as xfail, skipIf as skipif
+from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest

 import numpy
 import pytest

@@ -258,7 +258,7 @@ class ProcessGroup:
     def backend(self) -> str: ...
     @property
     def _timeout(self) -> timedelta: ...
-    @timeout.setter
+    @_timeout.setter
     def _timeout(self, val: timedelta) -> None: ...

 class BackendType(Enum):
@@ -507,11 +507,11 @@ class ProcessGroupNCCL(ProcessGroup):
     def backend(self) -> str: ...
     @property
     def _timeout(self) -> timedelta: ...
-    @timeout.setter
+    @_timeout.setter
     def _timeout(self, val: timedelta) -> None: ...
     @property
     def _is_high_priority_stream(self) -> bool: ...
-    @is_high_priority_stream.setter
+    @_is_high_priority_stream.setter
     def _is_high_priority_stream(self, val: bool) -> None: ...

     def __init__(

@@ -1467,7 +1467,7 @@ if TYPE_CHECKING:
     from torch._C._VariableFunctions import *  # type: ignore[assignment, misc] # noqa: F403
 # Fixup segment_reduce visibility
 _segment_reduce = segment_reduce
-del segment_reduce
+del segment_reduce  # noqa: F821

 # Ops not to be exposed in `torch` namespace,
 # mostly helper ops.

@@ -1577,6 +1577,9 @@ class BuiltinVariable(VariableTracker):
             # not broadcastable, can't be compared
             _unimplemented()
+        tensor_cls = left if isinstance(left, TensorVariable) else right
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
         return wrap_fx_proxy_cls(
             type(tensor_cls),  # handle Ndarrays and Tensors
             tx,

@@ -2,7 +2,7 @@ import functools
 import inspect
 import itertools
 import types
-from typing import Dict, List
+from typing import Dict, List, Optional, TYPE_CHECKING, Union

 import torch
@@ -13,6 +13,9 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
 from ..utils import get_first_attr, make_cell
 from .base import typestr, VariableTracker

+if TYPE_CHECKING:
+    from torch._guards import Source
+

 def wrap_bound_arg(tx, val, source=None):
     # Source propagation is best effort since not every object we encounter has a source to begin with.

@@ -3,7 +3,7 @@ MAX_CYCLE = 3000
 import itertools
 import operator

-from typing import List, Optional
+from typing import Dict, List, Optional

 from .. import polyfill, variables
 from ..exc import unimplemented

@@ -1149,8 +1149,8 @@ def aot_export_joint_simple(

     if config.debug_assert:
         # Smoke test that after partitioning, we can run the forward without any calling convention changes.
-        fw_module, bw_module = aot_config.default_partition(
-            fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)
+        fw_module, bw_module = aot_config.default_partition(  # noqa: F821
+            fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)  # noqa: F821
         )
     # Attempt to run the fw_module with the original user inputs
     fake_mode = detect_fake_mode(args)

@@ -14,7 +14,7 @@ import itertools
 import sympy
 from collections import defaultdict
 from torch.fx.passes import graph_drawer
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 from .compile_utils import fx_graph_cse, get_aten_target
 from . import config
 import functools
@@ -325,7 +325,8 @@ def _size_of(node: fx.Node) -> int:
     # Only needed since we don't always trace with fake tensors.
     if 'tensor_meta' in node.meta:
         metadata = node.meta['tensor_meta']
-        numel = _prod(map(to_size_hint, metadata.shape))
+        # TODO: What is to_size_hint suppose to be?
+        numel = _prod(map(to_size_hint, metadata.shape))  # noqa: F821
         dtype = metadata.dtype
     else:
         return 0

@@ -17,6 +17,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )
@@ -40,6 +41,9 @@ from ..utils import (
 )
 from ..virtualized import ops, OpsValue, V

+if TYPE_CHECKING:
+    from ..ir import TensorBox
+
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@@ -1288,7 +1292,7 @@ class ChoiceCaller:
     def hash_key(self) -> str:
         raise NotImplementedError()

-    def output_node(self) -> "TensorBox":  # type: ignore[name-defined]
+    def output_node(self) -> "TensorBox":
         raise NotImplementedError()

@@ -180,7 +180,7 @@ def gen_ops() -> List[Any]:

 def dtype_match(
     torch_dtype: Optional[torch.dtype],
-    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]
+    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined] # noqa: F821
 ) -> bool:
     # Import cutlass python scripts.
     assert try_import_cutlass()

@@ -264,7 +264,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         return res

     @staticmethod
-    def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]":  # type: ignore[name-defined]
+    def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]":  # type: ignore[name-defined]  # noqa: F821
         assert cutlass_utils.try_import_cutlass()
         import cutlass_library.library as cutlass_lib
@@ -277,8 +277,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):

     @staticmethod
     def flip_cutlass_layout(
-        cutlass_layout: "cutlass_lib.LayoutType",  # type: ignore[name-defined]
-    ) -> "cutlass_lib.LayoutType":  # type: ignore[name-defined]
+        cutlass_layout: "cutlass_lib.LayoutType",  # type: ignore[name-defined]  # noqa: F821
+    ) -> "cutlass_lib.LayoutType":  # type: ignore[name-defined]  # noqa: F821
         assert cutlass_utils.try_import_cutlass()
         import cutlass_library.library as cutlass_lib
@@ -312,7 +312,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         return result

     @staticmethod
-    def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool:  # type: ignore[name-defined]
+    def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool:  # type: ignore[name-defined]  # noqa: F821
         """
         returns True if the op is capable of flexible epilogue fusions
         using epilogue visitor trees.
@@ -345,7 +345,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):

     def define_gemm_instance(
         self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
         output_buffer_name: str,
         epilogue_nodes: Optional[List[IRNode]] = None,
     ) -> Tuple[str, str]:
@@ -408,8 +408,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):

     @staticmethod
     def swap_XW(
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]
-    ) -> "cutlass_library.gemm_op.GemmOperation":  # type: ignore[name-defined]
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> "cutlass_library.gemm_op.GemmOperation":  # type: ignore[name-defined]  # noqa: F821
         # Swap X and W in GemmOperation.
         new_op = copy.deepcopy(op)
         new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout)
@@ -421,8 +421,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):

     def filter_op(
         self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]
-    ) -> "cutlass_library.gemm_op.GemmOperation":  # type: ignore[name-defined]
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> "cutlass_library.gemm_op.GemmOperation":  # type: ignore[name-defined]  # noqa: F821
         assert cutlass_utils.try_import_cutlass()
         import cutlass_library.library as cutlass_lib
@@ -508,7 +508,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
             return None
         return op

-    def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]":  # type: ignore[name-defined]
+    def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]":  # type: ignore[name-defined]  # noqa: F821
         assert cutlass_utils.try_import_cutlass()
         import cutlass_library.gemm_operation as cutlass_gemm_op
         import cutlass_library.library as cutlass_lib
@@ -619,7 +619,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
     def render(  # type: ignore[override]
         self,
         kernel: CUDATemplateKernel,
-        op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]
+        op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
         template_buffer_node: Optional[CUDATemplateBuffer] = None,
         epilogue_nodes: Optional[List[IRNode]] = None,
         **kwargs,

@@ -3312,7 +3312,7 @@ class CUDATemplateBuffer(TemplateBuffer):
         inputs,
         make_kernel_render,
         workspace_size: int,
-        template: "CUDATemplate",  # type: ignore[name-defined]
+        template: "CUDATemplate",  # type: ignore[name-defined]  # noqa: F821
     ):
         super().__init__(layout, inputs, make_kernel_render)
         # Global memory (in bytes) needed for this template.

@@ -375,7 +375,7 @@ def create_submodule_from_subgraph(
         # TODO(future PR): this is ignoring kwargs, will need to support kwargs
         # for any fusion pattern which has them for a node that is not the
        # first node.
-        cur_args_copy = [cur_node_copy]  # type: ignore[has-type]
+        cur_args_copy = [cur_node_copy]  # type: ignore[has-type]  # noqa: F821

         if len(cur_node_orig.args) > 1:
             for arg in cur_node_orig.args[1:]:

@@ -1,7 +1,7 @@
 import dataclasses
 import itertools
 import operator
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING

 import torch
 from torch.fx import Graph, GraphModule, Node
@@ -26,6 +26,8 @@ from .utils import (
     get_aten_graph_module,
 )

+if TYPE_CHECKING:
+    from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch

 __all__ = []  # type: ignore[var-annotated]
@@ -259,7 +261,7 @@ def _get_folded_quantized_qat_conv_bn_pattern(
     return _folded_quantized_qat_conv_bn_pattern

 def _has_conv_bias_filter(
-    match: "InternalMatch",  # type: ignore[name-defined]
+    match: "InternalMatch",
     original_graph: Graph,
     pattern_graph: Graph,
 ) -> bool:
@@ -273,7 +275,7 @@ def _has_conv_bias_filter(
     raise ValueError("Could not find conv node in matched conv + bn pattern")

 def _no_conv_bias_filter(
-    match: "InternalMatch",  # type: ignore[name-defined]
+    match: "InternalMatch",
     original_graph: Graph,
     pattern_graph: Graph,
 ) -> bool:

@@ -20,6 +20,7 @@ from typing import (
     Set,
     Tuple,
     Type,
+    TYPE_CHECKING,
 )

 import torch
@@ -44,6 +45,9 @@ from .api import (
     StateDictType,
 )

+if TYPE_CHECKING:
+    from ._flat_param import FlatParamHandle
+
 FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
 FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
 FSDP_FLATTENED = "_fsdp_flattened"

@@ -5,12 +5,15 @@ import math
 import sys
 import weakref
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch
 from torch._subclasses.fake_tensor import FakeTensor
 from .exported_program import ExportedProgram

+if TYPE_CHECKING:
+    from ..fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+

 __all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]
@@ -118,7 +121,7 @@ class Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
     """

     # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
-    constraint_range: "StrictMinMaxConstraint"  # type: ignore[name-defined]
+    constraint_range: "StrictMinMaxConstraint"
     # Represent that `constraint_range` is shared with another _ConstraintTarget, which
     # typically arises because of a specified equality with another dynamic dimension.
     shared: Optional[_ConstraintTarget] = None

@@ -15,7 +15,21 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import lru_cache
-from typing import Any, cast, Callable, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable
+from typing import (
+    Any,
+    cast,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    TYPE_CHECKING
+)

 import torch
 import torch.fx
@@ -44,6 +58,9 @@ from torch._utils_internal import signpost_event

 from torch._logging import LazyString

+if TYPE_CHECKING:
+    from torch._dynamo.source import TensorPropertySource
+
 InputList = List
 DimList = List
@@ -2066,7 +2083,7 @@ class ShapeEnv:
         source: Source,
         symbolic_context: SymbolicContext
     ) -> List[sympy.Expr]:
-        return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context)
+        return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context)  # noqa: F821

     def _produce_dyn_sizes_from_int_tuple(self,
                                           tensor_size: Tuple[int],

@@ -6,9 +6,12 @@ from ._compatibility import compatibility

 import copy
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
 import torch

+if TYPE_CHECKING:
+    from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
+
 __all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]

 @compatibility(is_backward_compatible=True)
@@ -202,7 +205,7 @@ def replace_pattern_with_filters(
     gm: GraphModule,
     pattern: Union[Callable, Graph, GraphModule],
     replacement: Union[Callable, Graph, GraphModule],
-    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,  # type: ignore[name-defined]
+    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
     ignore_literals: bool = False,
 ) -> List[ReplacedPatterns]:
     """
@@ -222,7 +225,7 @@ def _replace_pattern(
     gm: GraphModule,
     pattern: Union[Callable, Graph, GraphModule],
     replacement: Union[Callable, Graph, GraphModule],
-    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,  # type: ignore[name-defined]
+    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
     ignore_literals: bool = False,
 ) -> List[ReplacedPatterns]:

@@ -2124,7 +2124,7 @@ class Module:
                 if child is not None:
                     child_prefix = prefix + name + '.'
                     child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-                    load(child, child_state_dict, child_prefix)
+                    load(child, child_state_dict, child_prefix)  # noqa: F821

             # Note that the hook can modify missing_keys and unexpected_keys.
             incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

@@ -5,7 +5,7 @@

 from __future__ import annotations

-from typing import Any, Callable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union

 import torch._dynamo
 import torch.fx
@@ -13,6 +13,9 @@ import torch.onnx
 from torch.onnx._internal import _beartype, exporter, io_adapter
 from torch.onnx._internal.diagnostics import infra

+if TYPE_CHECKING:
+    from torch.export.exported_program import ExportedProgram
+

 class TorchExport(exporter.FXGraphExtractor):
     """Generates a FX GraphModule using torch.export API
@@ -31,7 +34,7 @@ class TorchExport(exporter.FXGraphExtractor):
     def generate_fx(
         self,
         options: exporter.ResolvedExportOptions,
-        model: "ExportedProgram",  # type: ignore[name-defined]
+        model: "ExportedProgram",  # type: ignore[override]
         model_args: Sequence[Any],
         model_kwargs: Mapping[str, Any],
     ) -> torch.fx.GraphModule:

@@ -23,17 +23,17 @@ from .lbfgs import LBFGS
 from . import lr_scheduler
 from . import swa_utils

-del adadelta
-del adagrad
-del adam
-del adamw
-del sparse_adam
-del adamax
-del asgd
-del sgd
-del radam
-del rprop
-del rmsprop
-del optimizer
-del nadam
-del lbfgs
+del adadelta  # noqa: F821
+del adagrad  # noqa: F821
+del adam  # noqa: F821
+del adamw  # noqa: F821
+del sparse_adam  # noqa: F821
+del adamax  # noqa: F821
+del asgd  # noqa: F821
+del sgd  # noqa: F821
+del radam  # noqa: F821
+del rprop  # noqa: F821
+del rmsprop  # noqa: F821
+del optimizer  # noqa: F821
+del nadam  # noqa: F821
+del lbfgs  # noqa: F821

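Each `del` above needs its own suppression because the deleted names are bound only as a side effect of the import system: `from .adadelta import Adadelta` also imports the `adadelta` submodule and binds it in the package namespace, but pyflakes sees no explicit assignment and reports F821 even though the `del` succeeds at runtime. A runnable sketch of the side effect, using the standard library in place of torch.optim:

import collections

# Importing a name *from* a submodule also imports the submodule itself
# and attaches it to the parent package.
from collections.abc import Mapping  # noqa: F401

print(hasattr(collections, "abc"))  # True: `abc` was bound implicitly

Inside a package's own __init__.py, that implicit binding appears as a bare global, which is exactly what torch.optim deletes to keep the submodule names out of its public namespace.
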
@@ -206,7 +206,7 @@ def run_ddp(rank, world_size, prepared):
     prepared.to(rank)
     model_with_ddp = prepared
     optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
-    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)
+    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
    ddp_cleanup()


@@ -199,7 +199,7 @@ TestEnvironment.def_flag(
                          include_in_repro=False)
 # NB: enabled by default unless in an fbcode context.
 TestEnvironment.def_flag("PRINT_REPRO_ON_FAILURE", env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
-                         default=(not IS_FBCODE), include_in_repro=False)
+                         default=(not IS_FBCODE), include_in_repro=False)  # noqa: F821

 DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
 DEFAULT_SLOW_TESTS_FILE = '.pytorch-slow-tests.json'

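Most of the F821 suppressions in this file trace back to the same cause: `TestEnvironment.def_flag` creates names such as `IS_FBCODE` and `TEST_WITH_ROCM` as module globals at runtime, so a static linter never sees a binding for them. A minimal sketch of the presumed mechanism (the real helper also records flags for repro output and supports implied flags):

import os
import sys

def def_flag(name: str, env_var: str, default: bool = False) -> None:
    # Bind the flag as a global of this module at runtime; flake8 cannot
    # see this assignment, so later references trip F821.
    raw = os.getenv(env_var)
    value = default if raw is None else raw not in ("", "0", "false")
    setattr(sys.modules[__name__], name, value)

def_flag("TEST_WITH_SLOW", env_var="PYTORCH_TEST_WITH_SLOW")
print(TEST_WITH_SLOW)  # noqa: F821 -- defined dynamically above
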
@@ -760,7 +760,7 @@ parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
 parser.add_argument('--use-pytest', action='store_true')
 parser.add_argument('--save-xml', nargs='?', type=str,
                     const=_get_test_report_path(),
-                    default=_get_test_report_path() if IS_CI else None)
+                    default=_get_test_report_path() if IS_CI else None)  # noqa: F821
 parser.add_argument('--discover-tests', action='store_true')
 parser.add_argument('--log-suffix', type=str, default="")
 parser.add_argument('--run-parallel', type=int, default=1)

@@ -1289,7 +1289,7 @@ TestEnvironment.def_flag("TEST_SKIP_FAST", env_var="PYTORCH_TEST_SKIP_FAST")
 TestEnvironment.def_flag("TEST_WITH_CROSSREF", env_var="PYTORCH_TEST_WITH_CROSSREF")

 TestEnvironment.def_flag("TEST_SKIP_CUDAGRAPH", env_var="PYTORCH_TEST_SKIP_CUDAGRAPH")
-TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
+TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (  # noqa: F821
     (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 11) or
     (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
 )

@@ -1303,7 +1303,7 @@ if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
 def skipIfCrossRef(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        if TEST_WITH_CROSSREF:
+        if TEST_WITH_CROSSREF:  # noqa: F821
             raise unittest.SkipTest("test doesn't currently with crossref")
         else:
             fn(*args, **kwargs)

@@ -1320,25 +1320,25 @@ TestEnvironment.def_flag("TEST_WITH_TORCHINDUCTOR", env_var="PYTORCH_TEST_WITH_I
 # AOT_EAGER not tested in ci, useful for debugging
 TestEnvironment.def_flag("TEST_WITH_AOT_EAGER", env_var="PYTORCH_TEST_WITH_AOT_EAGER")
 TestEnvironment.def_flag("TEST_WITH_TORCHDYNAMO", env_var="PYTORCH_TEST_WITH_DYNAMO",
-                         implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER)
+                         implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER)  # noqa: F821

-if TEST_WITH_TORCHDYNAMO:
+if TEST_WITH_TORCHDYNAMO:  # noqa: F821
     import torch._dynamo
     # Do not spend time on helper functions that are called with different inputs
     torch._dynamo.config.accumulated_cache_size_limit = 8
     # Do not log compilation metrics from unit tests
     torch._dynamo.config.log_compilation_metrics = False
-    if TEST_WITH_TORCHINDUCTOR:
+    if TEST_WITH_TORCHINDUCTOR:  # noqa: F821
         import torch._inductor.config
         torch._inductor.config.fallback_random = True


 def xpassIfTorchDynamo(func):
-    return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)
+    return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)  # noqa: F821


 def xfailIfTorchDynamo(func):
-    return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func
+    return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func  # noqa: F821


 def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):

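`xpassIfTorchDynamo` and `xfailIfTorchDynamo` are ordinary conditional decorators around `unittest.expectedFailure`; only the dynamically created flag needs silencing. A self-contained sketch of the same shape:

import unittest

RUNNING_UNDER_COMPILER = False  # stand-in for the dynamically defined flag

def xfail_if(condition: bool):
    def decorator(func):
        # expectedFailure inverts the outcome: a failing test passes,
        # an unexpectedly passing test fails.
        return unittest.expectedFailure(func) if condition else func
    return decorator

class Demo(unittest.TestCase):
    @xfail_if(RUNNING_UNDER_COMPILER)
    def test_addition(self):
        self.assertEqual(1 + 1, 2)

if __name__ == "__main__":
    unittest.main()
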
@@ -1346,14 +1346,14 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
         if not isinstance(fn, type):
             @wraps(fn)
             def wrapper(*args, **kwargs):
-                if TEST_WITH_TORCHDYNAMO:
+                if TEST_WITH_TORCHDYNAMO:  # noqa: F821
                     raise unittest.SkipTest(msg)
                 else:
                     fn(*args, **kwargs)
             return wrapper

         assert(isinstance(fn, type))
-        if TEST_WITH_TORCHDYNAMO:
+        if TEST_WITH_TORCHDYNAMO:  # noqa: F821
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = msg

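The hunk above shows the decorator's two paths: plain test functions get a wrapper that raises `unittest.SkipTest`, while test classes get the `__unittest_skip__`/`__unittest_skip_why__` attributes that `unittest.skip` itself sets, skipping the whole class without wrapping every method. A sketch of the dual-path pattern:

import functools
import unittest

def skip_when(condition: bool, msg: str):
    def decorator(obj):
        if isinstance(obj, type):
            # The same attributes unittest.skip() sets on a class.
            if condition:
                obj.__unittest_skip__ = True
                obj.__unittest_skip_why__ = msg
            return obj

        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            if condition:
                raise unittest.SkipTest(msg)
            return obj(*args, **kwargs)
        return wrapper
    return decorator
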
@@ -1363,7 +1363,7 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
     return decorator

 def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
-                        condition=TEST_WITH_TORCHINDUCTOR):
+                        condition=TEST_WITH_TORCHINDUCTOR):  # noqa: F821
     def decorator(fn):
         if not isinstance(fn, type):
             @wraps(fn)

@@ -1427,7 +1427,7 @@ def markDynamoStrictTest(cls_or_func=None, nopython=False):


 def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
-    return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
+    return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)  # noqa: F821

 def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
     def decorator(fn):

@@ -1535,7 +1535,7 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"

         @wraps(fn)
         def wrapper(*args, **kwargs):
-            if TEST_WITH_ROCM:
+            if TEST_WITH_ROCM:  # noqa: F821
                 raise unittest.SkipTest(reason)
             else:
                 return fn(*args, **kwargs)

@@ -1547,7 +1547,7 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        if TEST_WITH_ROCM:
+        if TEST_WITH_ROCM:  # noqa: F821
             fn(*args, **kwargs)
         else:
             raise unittest.SkipTest("test currently only works on the ROCm stack")

@@ -1567,7 +1567,7 @@ def skipIfRocmVersionLessThan(version=None):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
+            if TEST_WITH_ROCM:  # noqa: F821
                 rocm_version = str(torch.version.hip)
                 rocm_version = rocm_version.split("-")[0]  # ignore git sha
                 rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))

@@ -1811,7 +1811,7 @@ def skipIfTBB(message="This test makes TBB sad"):
 def slowTest(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        if not TEST_WITH_SLOW:
+        if not TEST_WITH_SLOW:  # noqa: F821
             raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
         else:
             fn(*args, **kwargs)

@@ -2177,7 +2177,7 @@ try:
             verbosity=hypothesis.Verbosity.verbose))

     hypothesis.settings.load_profile(
-        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
+        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')  # noqa: F821
     )
 except ImportError:
     print('Fail to import hypothesis in common_utils, tests are not derandomized')

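The flagged name here is just `IS_CI`; the surrounding code is standard Hypothesis profile selection, where named settings bundles are registered once and one is loaded by name. A small sketch of that idiom (the profile names and CI check are illustrative):

import os
import hypothesis

hypothesis.settings.register_profile("ci", max_examples=200)
hypothesis.settings.register_profile("dev", max_examples=10)
hypothesis.settings.load_profile(
    "ci" if os.getenv("CI") else os.getenv("HYPOTHESIS_PROFILE", "dev")
)
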
@@ -2218,10 +2218,10 @@ def check_if_enable(test: unittest.TestCase):

     if any(matches_test(x) for x in slow_tests_dict.keys()):
         getattr(test, test._testMethodName).__dict__['slow_test'] = True
-        if not TEST_WITH_SLOW:
+        if not TEST_WITH_SLOW:  # noqa: F821
             raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")

-    if not IS_SANDCASTLE:
+    if not IS_SANDCASTLE:  # noqa: F821
         should_skip = False
         skip_msg = ""

@@ -2233,11 +2233,11 @@ def check_if_enable(test: unittest.TestCase):
             "win": IS_WINDOWS,
             "windows": IS_WINDOWS,
             "linux": IS_LINUX,
-            "rocm": TEST_WITH_ROCM,
-            "asan": TEST_WITH_ASAN,
-            "dynamo": TEST_WITH_TORCHDYNAMO,
-            "inductor": TEST_WITH_TORCHINDUCTOR,
-            "slow": TEST_WITH_SLOW,
+            "rocm": TEST_WITH_ROCM,  # noqa: F821
+            "asan": TEST_WITH_ASAN,  # noqa: F821
+            "dynamo": TEST_WITH_TORCHDYNAMO,  # noqa: F821
+            "inductor": TEST_WITH_TORCHINDUCTOR,  # noqa: F821
+            "slow": TEST_WITH_SLOW,  # noqa: F821
         }

         invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))

@@ -2270,7 +2270,7 @@ def check_if_enable(test: unittest.TestCase):
                 " disabled tests are run"
             raise unittest.SkipTest(skip_msg)

-    if TEST_SKIP_FAST:
+    if TEST_SKIP_FAST:  # noqa: F821
         if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
             raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")

@@ -2565,7 +2565,7 @@ class TestCase(expecttest.TestCase):
         test_method = getattr(self, method_name, None)
         if test_method is not None:
             # Wraps the tested method if we should do CUDA memory check.
-            if TEST_CUDA_MEM_LEAK_CHECK:
+            if TEST_CUDA_MEM_LEAK_CHECK:  # noqa: F821
                 self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
                 # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
                 if self._do_cuda_memory_leak_check and not IS_WINDOWS:

@@ -2579,7 +2579,7 @@ class TestCase(expecttest.TestCase):
         if self._ignore_not_implemented_error:
             self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))

-        if PRINT_REPRO_ON_FAILURE:
+        if PRINT_REPRO_ON_FAILURE:  # noqa: F821
             env_var_prefix = TestEnvironment.repro_env_var_prefix()
             try:
                 def _get_rel_test_path(abs_test_path):

@@ -2684,7 +2684,7 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
         test_cls = super_run.__self__

         # Are we compiling?
-        compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
+        compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR  # noqa: F821
         # Is the class strict and compiling?
         strict_default = False
         if compiled:

@@ -2716,11 +2716,11 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
             else:
                 supress_errors = torch._dynamo.config.suppress_errors
             with unittest.mock.patch("torch._dynamo.config.suppress_errors", supress_errors):
-                if TEST_WITH_TORCHINDUCTOR:
+                if TEST_WITH_TORCHINDUCTOR:  # noqa: F821
                     super_run = torch._dynamo.optimize("inductor")(super_run)
-                elif TEST_WITH_AOT_EAGER:
+                elif TEST_WITH_AOT_EAGER:  # noqa: F821
                     super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
-                elif TEST_WITH_TORCHDYNAMO:
+                elif TEST_WITH_TORCHDYNAMO:  # noqa: F821
                     # TorchDynamo optimize annotation
                     super_run = torch._dynamo.optimize("eager", nopython=nopython)(super_run)

             key = f"{self.__class__.__name__}.{self._testMethodName}"

@@ -2757,7 +2757,7 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""

     def run(self, result=None):
         with contextlib.ExitStack() as stack:
-            if TEST_WITH_CROSSREF:
+            if TEST_WITH_CROSSREF:  # noqa: F821
                 stack.enter_context(CrossRefMode())
             self._run_custom(
                 result=result,

@@ -4312,7 +4312,7 @@ def load_tests(loader, tests, pattern):
     set_running_script_path()
     test_suite = unittest.TestSuite()
     for test_group in tests:
-        if not DISABLE_RUNNING_SCRIPT_CHK:
+        if not DISABLE_RUNNING_SCRIPT_CHK:  # noqa: F821
             for test in test_group:
                 check_test_defined_in_running_script(test)
         if test_group._tests:

@@ -4337,7 +4337,7 @@ GRADCHECK_NONDET_TOL = 1e-12
 TestEnvironment.def_flag("TEST_WITH_SLOW_GRADCHECK", env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK")

 skipIfSlowGradcheckEnv = unittest.skipIf(
-    TEST_WITH_SLOW_GRADCHECK,
+    TEST_WITH_SLOW_GRADCHECK,  # noqa: F821
     "Tests that don't use gradcheck don't need to run on slow_gradcheck CI"
 )

@@ -4353,7 +4353,7 @@ def gradcheck(fn, inputs, **kwargs):
         "fast_mode": True,
     }

-    if TEST_WITH_SLOW_GRADCHECK:
+    if TEST_WITH_SLOW_GRADCHECK:  # noqa: F821
         default_values["fast_mode"] = False

     for key, value in default_values.items():

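`gradcheck` here (and `gradgradcheck` in the next hunk) merges caller kwargs with a defaults dict, then flips `fast_mode` off for the slow-gradcheck CI job. The merge is just filling in unset keys; a sketch against the public `torch.autograd.gradcheck`, which requires double-precision inputs:

import torch

def gradcheck_with_defaults(fn, inputs, **kwargs):
    defaults = {"check_batched_grad": True, "fast_mode": True}
    for key, value in defaults.items():
        # Fill in a default only when the caller did not override it.
        kwargs.setdefault(key, value)
    return torch.autograd.gradcheck(fn, inputs, **kwargs)

x = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck_with_defaults(torch.sin, (x,))
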
@@ -4373,7 +4373,7 @@ def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
         "fast_mode": True,
     }

-    if TEST_WITH_SLOW_GRADCHECK:
+    if TEST_WITH_SLOW_GRADCHECK:  # noqa: F821
         default_values["fast_mode"] = False

     for key, value in default_values.items():

@@ -4468,7 +4468,7 @@ def skip_but_pass_in_sandcastle(reason):
     skipping continuously.
     """
     def decorator(func):
-        if not IS_SANDCASTLE:
+        if not IS_SANDCASTLE:  # noqa: F821
             func.__unittest_skip__ = True
             func.__unittest_skip_why__ = reason
             return func

@@ -4573,7 +4573,7 @@ def skip_but_pass_in_sandcastle_if(condition, reason):
     """
     def decorator(func):
         if condition:
-            if IS_SANDCASTLE:
+            if IS_SANDCASTLE:  # noqa: F821
                 @wraps(func)
                 def wrapper(*args, **kwargs):
                     print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)

@@ -4738,7 +4738,7 @@ class TestGradients(TestCase):
         include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex

         samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
-                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)
+                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)  # noqa: F821

         for sample in samples:
             if sample.broadcasts_input and is_inplace(variant):

@@ -7,7 +7,7 @@ import torch.jit
 import torch.jit._logging
 import torch.jit.frontend
 from torch.testing._internal.common_nn import module_tests, new_module_tests
-from torch.testing._internal.common_utils import is_iterable_of_tensors
+from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like

 import collections
 from copy import deepcopy

@@ -3,7 +3,7 @@ import random
 import unittest
 from functools import partial
 from itertools import chain, product
-from typing import Iterable, List
+from typing import Iterable, List, Tuple

 import numpy as np
 from numpy import inf