diff --git a/mlxtend/feature_selection/exhaustive_feature_selector.py b/mlxtend/feature_selection/exhaustive_feature_selector.py index 523c4ba19..7dc402a7b 100644 --- a/mlxtend/feature_selection/exhaustive_feature_selector.py +++ b/mlxtend/feature_selection/exhaustive_feature_selector.py @@ -169,8 +169,8 @@ class ExhaustiveFeatureSelector(BaseEstimator, MetaEstimatorMixin): def __init__( self, estimator, - min_features=1, - max_features=1, + min_features=None, + max_features=None, print_progress=True, scoring="accuracy", cv=5, @@ -179,11 +179,13 @@ def __init__( clone_estimator=True, fixed_features=None, feature_groups=None, + feature_range = None ): self.estimator = estimator self.min_features = min_features self.max_features = max_features self.pre_dispatch = pre_dispatch + self.feature_range = feature_range # Want to raise meaningful error message if a # cross-validation generator is inputted if isinstance(cv, types.GeneratorType): @@ -401,8 +403,26 @@ def fit(self, X, y, groups=None, **fit_params): # candidates in the following lines are the non-fixed-features candidates # (the fixed features will be added later to each combination) - min_num_candidates = self.min_features - len(self.fixed_features_group_set) - max_num_candidates = self.max_features - len(self.fixed_features_group_set) + + if self.min_features == None and self.max_features == None: + if self.feature_range == None: + min_num_candidates = 1 + max_num_candidates = 1 + elif self.feature_range == "all": + min_num_candidates = 1 + max_num_candidates = X.shape[1] + else: + try: + min_num_candidates = self.feature_range[0] + max_num_candidates = self.feature_range[1] + except ValueError: + raise ValueError( + """feature_range should be of tuple type. First argument should be + minimum number of feature and second argument shoulf be maximum number of feature""" + ) + else: + min_num_candidates = self.min_features - len(self.fixed_features_group_set) + max_num_candidates = self.max_features - len(self.fixed_features_group_set) candidates = chain.from_iterable( combinations(non_fixed_groups, r=i) for i in range(min_num_candidates, max_num_candidates + 1)