[MNT] unify the code format
This commit is contained in:
parent
be1f1b0bc6
commit
a09b619d82
Binary file not shown.
Before Width: | Height: | Size: 712 KiB |
|
@ -1,6 +1,5 @@
|
|||
from learnware.tests.benchmarks import BenchmarkConfig
|
||||
|
||||
|
||||
image_benchmark_config = BenchmarkConfig(
|
||||
name="CIFAR-10",
|
||||
user_num=100,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from torch import optim, nn
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from learnware.utils import choose_device
|
||||
|
|
|
@ -1,24 +1,25 @@
|
|||
import os
|
||||
import fire
|
||||
import time
|
||||
import torch
|
||||
import pickle
|
||||
import random
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils.data import TensorDataset
|
||||
import time
|
||||
|
||||
import fire
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from config import image_benchmark_config
|
||||
from model import ConvModel
|
||||
from torch.utils.data import TensorDataset
|
||||
from utils import evaluate, train_model
|
||||
|
||||
from learnware.utils import choose_device
|
||||
from learnware.client import LearnwareClient
|
||||
from learnware.logger import get_module_logger
|
||||
from learnware.market import BaseUserInfo, instantiate_learnware_market
|
||||
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
|
||||
from learnware.specification import generate_stat_spec
|
||||
from learnware.tests.benchmarks import LearnwareBenchmark
|
||||
from learnware.market import instantiate_learnware_market, BaseUserInfo
|
||||
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
|
||||
from model import ConvModel
|
||||
from utils import train_model, evaluate
|
||||
from config import image_benchmark_config
|
||||
from learnware.utils import choose_device
|
||||
|
||||
logger = get_module_logger("image_workflow", level="INFO")
|
||||
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
import os
|
||||
import time
|
||||
import random
|
||||
import requests
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from config import market_mapping_params
|
||||
from methods import loss_func_rmse, test_methods
|
||||
from utils import set_seed
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
from learnware.logger import get_module_logger
|
||||
from learnware.market import instantiate_learnware_market
|
||||
from learnware.reuse.utils import fill_data_with_mean
|
||||
from learnware.tests.benchmarks import LearnwareBenchmark
|
||||
|
||||
from config import market_mapping_params
|
||||
from methods import loss_func_rmse, test_methods
|
||||
from utils import set_seed
|
||||
|
||||
logger = get_module_logger("base_table", level="INFO")
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from learnware.tests.benchmarks import BenchmarkConfig
|
||||
|
||||
|
||||
homo_n_labeled_list = [100, 200, 500, 1000, 2000, 4000, 6000, 8000, 10000]
|
||||
homo_n_repeat_list = [10, 10, 10, 3, 3, 3, 3, 3, 3]
|
||||
hetero_n_labeled_list = [10, 30, 50, 75, 100, 200, 500, 1000, 2000]
|
||||
|
|
|
@ -2,15 +2,15 @@ import os
|
|||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from base import TableWorkflow
|
||||
from config import align_model_params, hetero_n_labeled_list, hetero_n_repeat_list, user_semantic
|
||||
from methods import loss_func_rmse
|
||||
from utils import Recorder, plot_performance_curves, set_seed
|
||||
|
||||
from learnware.logger import get_module_logger
|
||||
from learnware.specification import generate_stat_spec
|
||||
from learnware.market import BaseUserInfo
|
||||
from learnware.reuse import AveragingReuser, FeatureAlignLearnware
|
||||
|
||||
from methods import loss_func_rmse
|
||||
from base import TableWorkflow
|
||||
from config import align_model_params, user_semantic, hetero_n_labeled_list, hetero_n_repeat_list
|
||||
from utils import Recorder, plot_performance_curves, set_seed
|
||||
from learnware.specification import generate_stat_spec
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logger = get_module_logger("hetero_test", level="INFO")
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from learnware.market import BaseUserInfo
|
||||
from learnware.logger import get_module_logger
|
||||
from learnware.specification import generate_stat_spec
|
||||
from learnware.reuse import AveragingReuser, JobSelectorReuser
|
||||
|
||||
from methods import loss_func_rmse
|
||||
from base import TableWorkflow
|
||||
from config import homo_n_labeled_list, homo_n_repeat_list
|
||||
from methods import loss_func_rmse
|
||||
from utils import Recorder, plot_performance_curves
|
||||
|
||||
from learnware.logger import get_module_logger
|
||||
from learnware.market import BaseUserInfo
|
||||
from learnware.reuse import AveragingReuser, JobSelectorReuser
|
||||
from learnware.specification import generate_stat_spec
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logger = get_module_logger("homo_table", level="INFO")
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import numpy as np
|
||||
from config import align_model_params
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from sklearn.model_selection import train_test_split
|
||||
from train import train_model
|
||||
|
||||
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware
|
||||
from config import align_model_params
|
||||
from train import train_model
|
||||
|
||||
|
||||
def loss_func_rmse(y_true, y_pred):
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import lightgbm as lgb
|
||||
from config import user_model_params
|
||||
from lightgbm import early_stopping
|
||||
|
||||
from learnware.logger import get_module_logger
|
||||
from config import user_model_params
|
||||
|
||||
logger = get_module_logger("train_table", level="INFO")
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
import os
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from config import labels, styles
|
||||
|
||||
from learnware.logger import get_module_logger
|
||||
from config import styles, labels
|
||||
|
||||
logger = get_module_logger("base_table", level="INFO")
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import fire
|
||||
|
||||
from learnware.logger import get_module_logger
|
||||
from homo import HomogeneousDatasetWorkflow
|
||||
from hetero import HeterogeneousDatasetWorkflow
|
||||
from config import (
|
||||
homo_table_benchmark_config,
|
||||
hetero_cross_feat_eng_benchmark_config,
|
||||
hetero_cross_task_benchmark_config,
|
||||
homo_table_benchmark_config,
|
||||
)
|
||||
from hetero import HeterogeneousDatasetWorkflow
|
||||
from homo import HomogeneousDatasetWorkflow
|
||||
|
||||
from learnware.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("base_table", level="INFO")
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from learnware.tests.benchmarks import BenchmarkConfig
|
||||
|
||||
|
||||
text_benchmark_config = BenchmarkConfig(
|
||||
name="20-Newsgroups",
|
||||
user_num=10,
|
||||
|
|
|
@ -1,22 +1,23 @@
|
|||
import os
|
||||
import fire
|
||||
import time
|
||||
import random
|
||||
import pickle
|
||||
import random
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import fire
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from config import text_benchmark_config
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
from learnware.logger import get_module_logger
|
||||
from learnware.market import BaseUserInfo, instantiate_learnware_market
|
||||
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
|
||||
from learnware.specification import RKMETextSpecification
|
||||
from learnware.tests.benchmarks import LearnwareBenchmark
|
||||
from learnware.market import instantiate_learnware_market, BaseUserInfo
|
||||
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
|
||||
from config import text_benchmark_config
|
||||
|
||||
logger = get_module_logger("text_workflow", level="INFO")
|
||||
|
||||
|
|
|
@ -2,10 +2,11 @@ import os
|
|||
import pickle
|
||||
import tempfile
|
||||
import zipfile
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .config import BenchmarkConfig, benchmark_configs
|
||||
from ..data import GetData
|
||||
from ...config import C
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import unittest
|
||||
import tempfile
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import learnware
|
||||
from learnware.learnware import Learnware
|
||||
from learnware.client import LearnwareClient
|
||||
from learnware.market import instantiate_learnware_market, BaseUserInfo
|
||||
from learnware.learnware import Learnware
|
||||
from learnware.market import BaseUserInfo, instantiate_learnware_market
|
||||
|
||||
learnware.init(logging_level=logging.WARNING)
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import os
|
||||
import json
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
from learnware.specification import generate_semantic_spec
|
||||
from learnware.market import BaseUserInfo
|
||||
from learnware.specification import generate_semantic_spec
|
||||
|
||||
|
||||
class TestAllLearnware(unittest.TestCase):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import unittest
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from learnware.client import LearnwareClient
|
||||
from learnware.specification import generate_semantic_spec
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import json
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification
|
||||
from learnware.specification import generate_stat_spec
|
||||
from learnware.market.heterogeneous.organizer import HeteroMap
|
||||
from learnware.specification import HeteroMapTableSpecification, RKMETableSpecification, generate_stat_spec
|
||||
|
||||
|
||||
class TestTableRKME(unittest.TestCase):
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from learnware.specification import RKMEImageSpecification
|
||||
from learnware.specification import generate_stat_spec
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from learnware.specification import RKMEImageSpecification, generate_stat_spec
|
||||
|
||||
|
||||
class TestImageRKME(unittest.TestCase):
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import os
|
||||
import json
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from learnware.specification import RKMETableSpecification
|
||||
from learnware.specification import generate_stat_spec
|
||||
from learnware.specification import RKMETableSpecification, generate_stat_spec
|
||||
|
||||
|
||||
class TestTableRKME(unittest.TestCase):
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
import json
|
||||
import string
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
import string
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from learnware.specification import RKMETextSpecification
|
||||
from learnware.specification import generate_stat_spec
|
||||
from learnware.specification import RKMETextSpecification, generate_stat_spec
|
||||
|
||||
|
||||
class TestTextRKME(unittest.TestCase):
|
||||
|
|
|
@ -1,22 +1,22 @@
|
|||
import torch
|
||||
import pickle
|
||||
import unittest
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
import zipfile
|
||||
from sklearn.linear_model import Ridge
|
||||
|
||||
import torch
|
||||
from hetero_config import input_description_list, input_shape_list, output_description_list, user_description_list
|
||||
from sklearn.datasets import make_regression
|
||||
from sklearn.linear_model import Ridge
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
import learnware
|
||||
from learnware.market import instantiate_learnware_market, BaseUserInfo
|
||||
from learnware.market import BaseUserInfo, instantiate_learnware_market
|
||||
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware
|
||||
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
|
||||
from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser
|
||||
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
|
||||
|
||||
from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list
|
||||
|
||||
learnware.init(logging_level=logging.WARNING)
|
||||
curr_root = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
import unittest
|
||||
import os
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
from sklearn import svm
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import learnware
|
||||
from learnware.market import instantiate_learnware_market, BaseUserInfo
|
||||
from learnware.market import BaseUserInfo, instantiate_learnware_market
|
||||
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, JobSelectorReuser
|
||||
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
|
||||
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser
|
||||
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
|
||||
|
||||
learnware.init(logging_level=logging.WARNING)
|
||||
|
|
Loading…
Reference in New Issue