[MNT] unify the code format

This commit is contained in:
bxdd 2024-01-15 23:47:13 +08:00
parent be1f1b0bc6
commit a09b619d82
30 changed files with 108 additions and 104 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 712 KiB

View File

View File

@ -1,6 +1,5 @@
from learnware.tests.benchmarks import BenchmarkConfig
image_benchmark_config = BenchmarkConfig(
name="CIFAR-10",
user_num=100,

View File

@ -1,6 +1,6 @@
import torch
import numpy as np
from torch import optim, nn
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from learnware.utils import choose_device

View File

@ -1,24 +1,25 @@
import os
import fire
import time
import torch
import pickle
import random
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset
import time
import fire
import matplotlib.pyplot as plt
import numpy as np
import torch
from config import image_benchmark_config
from model import ConvModel
from torch.utils.data import TensorDataset
from utils import evaluate, train_model
from learnware.utils import choose_device
from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec
from learnware.tests.benchmarks import LearnwareBenchmark
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from model import ConvModel
from utils import train_model, evaluate
from config import image_benchmark_config
from learnware.utils import choose_device
logger = get_module_logger("image_workflow", level="INFO")

View File

@ -1,20 +1,21 @@
import os
import time
import random
import requests
import tempfile
import time
import traceback
import numpy as np
import requests
from config import market_mapping_params
from methods import loss_func_rmse, test_methods
from utils import set_seed
from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market
from learnware.reuse.utils import fill_data_with_mean
from learnware.tests.benchmarks import LearnwareBenchmark
from config import market_mapping_params
from methods import loss_func_rmse, test_methods
from utils import set_seed
logger = get_module_logger("base_table", level="INFO")

View File

@ -1,6 +1,5 @@
from learnware.tests.benchmarks import BenchmarkConfig
homo_n_labeled_list = [100, 200, 500, 1000, 2000, 4000, 6000, 8000, 10000]
homo_n_repeat_list = [10, 10, 10, 3, 3, 3, 3, 3, 3]
hetero_n_labeled_list = [10, 30, 50, 75, 100, 200, 500, 1000, 2000]

View File

@ -2,15 +2,15 @@ import os
import warnings
import numpy as np
from base import TableWorkflow
from config import align_model_params, hetero_n_labeled_list, hetero_n_repeat_list, user_semantic
from methods import loss_func_rmse
from utils import Recorder, plot_performance_curves, set_seed
from learnware.logger import get_module_logger
from learnware.specification import generate_stat_spec
from learnware.market import BaseUserInfo
from learnware.reuse import AveragingReuser, FeatureAlignLearnware
from methods import loss_func_rmse
from base import TableWorkflow
from config import align_model_params, user_semantic, hetero_n_labeled_list, hetero_n_repeat_list
from utils import Recorder, plot_performance_curves, set_seed
from learnware.specification import generate_stat_spec
warnings.filterwarnings("ignore")
logger = get_module_logger("hetero_test", level="INFO")

View File

@ -1,17 +1,17 @@
import os
import warnings
import numpy as np
from learnware.market import BaseUserInfo
from learnware.logger import get_module_logger
from learnware.specification import generate_stat_spec
from learnware.reuse import AveragingReuser, JobSelectorReuser
from methods import loss_func_rmse
from base import TableWorkflow
from config import homo_n_labeled_list, homo_n_repeat_list
from methods import loss_func_rmse
from utils import Recorder, plot_performance_curves
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo
from learnware.reuse import AveragingReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec
warnings.filterwarnings("ignore")
logger = get_module_logger("homo_table", level="INFO")

View File

@ -1,10 +1,10 @@
import numpy as np
from config import align_model_params
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from train import train_model
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware
from config import align_model_params
from train import train_model
def loss_func_rmse(y_true, y_pred):

View File

@ -1,8 +1,8 @@
import lightgbm as lgb
from config import user_model_params
from lightgbm import early_stopping
from learnware.logger import get_module_logger
from config import user_model_params
logger = get_module_logger("train_table", level="INFO")

View File

@ -1,14 +1,14 @@
import os
import json
import os
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import torch
from config import labels, styles
from learnware.logger import get_module_logger
from config import styles, labels
logger = get_module_logger("base_table", level="INFO")

View File

@ -1,13 +1,13 @@
import fire
from learnware.logger import get_module_logger
from homo import HomogeneousDatasetWorkflow
from hetero import HeterogeneousDatasetWorkflow
from config import (
homo_table_benchmark_config,
hetero_cross_feat_eng_benchmark_config,
hetero_cross_task_benchmark_config,
homo_table_benchmark_config,
)
from hetero import HeterogeneousDatasetWorkflow
from homo import HomogeneousDatasetWorkflow
from learnware.logger import get_module_logger
logger = get_module_logger("base_table", level="INFO")

View File

@ -1,6 +1,5 @@
from learnware.tests.benchmarks import BenchmarkConfig
text_benchmark_config = BenchmarkConfig(
name="20-Newsgroups",
user_num=10,

View File

@ -1,22 +1,23 @@
import os
import fire
import time
import random
import pickle
import random
import tempfile
import numpy as np
import time
import fire
import matplotlib.pyplot as plt
import numpy as np
from config import text_benchmark_config
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfVectorizer
from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import RKMETextSpecification
from learnware.tests.benchmarks import LearnwareBenchmark
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from config import text_benchmark_config
logger = get_module_logger("text_workflow", level="INFO")

View File

@ -2,10 +2,11 @@ import os
import pickle
import tempfile
import zipfile
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
from .config import BenchmarkConfig, benchmark_configs
from ..data import GetData
from ...config import C

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Dict
from typing import Dict, List, Optional
@dataclass

View File

@ -1,4 +1,5 @@
import os
from setuptools import find_packages, setup

View File

@ -1,12 +1,12 @@
import os
import unittest
import tempfile
import logging
import os
import tempfile
import unittest
import learnware
from learnware.learnware import Learnware
from learnware.client import LearnwareClient
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.learnware import Learnware
from learnware.market import BaseUserInfo, instantiate_learnware_market
learnware.init(logging_level=logging.WARNING)

View File

@ -1,11 +1,11 @@
import os
import json
import unittest
import os
import tempfile
import unittest
from learnware.client import LearnwareClient
from learnware.specification import generate_semantic_spec
from learnware.market import BaseUserInfo
from learnware.specification import generate_semantic_spec
class TestAllLearnware(unittest.TestCase):

View File

@ -1,6 +1,6 @@
import os
import unittest
import tempfile
import unittest
from learnware.client import LearnwareClient

View File

@ -1,4 +1,5 @@
import unittest
import numpy as np
from learnware.client import LearnwareClient

View File

@ -1,5 +1,6 @@
import os
import unittest
import numpy as np
from learnware.client import LearnwareClient

View File

@ -1,7 +1,7 @@
import os
import json
import unittest
import os
import tempfile
import unittest
from learnware.client import LearnwareClient
from learnware.specification import generate_semantic_spec

View File

@ -1,12 +1,12 @@
import os
import json
import unittest
import os
import tempfile
import unittest
import numpy as np
from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification
from learnware.specification import generate_stat_spec
from learnware.market.heterogeneous.organizer import HeteroMap
from learnware.specification import HeteroMapTableSpecification, RKMETableSpecification, generate_stat_spec
class TestTableRKME(unittest.TestCase):

View File

@ -1,12 +1,12 @@
import os
import json
import torch
import unittest
import os
import tempfile
import numpy as np
import unittest
from learnware.specification import RKMEImageSpecification
from learnware.specification import generate_stat_spec
import numpy as np
import torch
from learnware.specification import RKMEImageSpecification, generate_stat_spec
class TestImageRKME(unittest.TestCase):

View File

@ -1,11 +1,11 @@
import os
import json
import unittest
import os
import tempfile
import unittest
import numpy as np
from learnware.specification import RKMETableSpecification
from learnware.specification import generate_stat_spec
from learnware.specification import RKMETableSpecification, generate_stat_spec
class TestTableRKME(unittest.TestCase):

View File

@ -1,12 +1,11 @@
import os
import json
import string
import os
import random
import unittest
import string
import tempfile
import unittest
from learnware.specification import RKMETextSpecification
from learnware.specification import generate_stat_spec
from learnware.specification import RKMETextSpecification, generate_stat_spec
class TestTextRKME(unittest.TestCase):

View File

@ -1,22 +1,22 @@
import torch
import pickle
import unittest
import os
import logging
import os
import pickle
import tempfile
import unittest
import zipfile
from sklearn.linear_model import Ridge
import torch
from hetero_config import input_description_list, input_shape_list, output_description_list, user_description_list
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
import learnware
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list
learnware.init(logging_level=logging.WARNING)
curr_root = os.path.dirname(os.path.abspath(__file__))

View File

@ -1,18 +1,19 @@
import unittest
import os
import logging
import tempfile
import os
import pickle
import tempfile
import unittest
import zipfile
import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import learnware
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, JobSelectorReuser
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
learnware.init(logging_level=logging.WARNING)