# Source: MIA-GCL / MVGRL / GCL2 / eval / svm.py
from sklearn.svm import LinearSVC, SVC
from GCL.eval import BaseSKLearnEvaluator


class SVMEvaluator(BaseSKLearnEvaluator):
    """Evaluator that plugs a scikit-learn SVM into ``BaseSKLearnEvaluator``.

    The constructor instantiates either ``LinearSVC`` or ``SVC`` with default
    settings and forwards it, together with a parameter mapping, to the base
    class initializer.
    """

    def __init__(self, linear=True, params=None):
        """Build the SVM estimator and hand it to the base class.

        Args:
            linear: When truthy (the default) a ``LinearSVC`` is created;
                otherwise an ``SVC`` is created.
            params: Parameter mapping forwarded to the base class. When
                ``None``, a default mapping with seven candidate values for
                ``C`` (0.001 through 1000, in powers of ten) is used.
                Presumably the base class treats this as a hyper-parameter
                search grid -- confirm against ``BaseSKLearnEvaluator``.
        """
        estimator_cls = LinearSVC if linear else SVC
        # Stored on the instance before delegating, then passed on as-is.
        self.evaluator = estimator_cls()

        grid = params
        if grid is None:
            grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}

        super().__init__(self.evaluator, grid)