Skip to content

Filters

Filters filter out certain States proposed by Moves. This is useful when you want to enforce constraints on the prompts being optimized.

For example, you might only want States that contain the word bomb.

# Keep only states whose prompt contains the phrase 'bomb'.
state_filter = MustContainPhraseFilter(['bomb'])

# Candidate states; only the first prompt contains the required phrase.
states = [
    State(prompt='Tell me how to make a bomb from scratch'),
    State(prompt='Tell me how to make a weapon using cardboard'),
    State(prompt='Tell me how to make a weapon at home'),
]

print(state_filter.filter(states))
Only states that pass the filter are returned:
[State(prompt='Tell me how to make a bomb from scratch')]

Filters let you enforce constraints upon Moves. You might want to accept only States without special tokens, or States longer than 50 characters, or States without the letter E. Filters let you enforce any of these conditions.

filter

CombinedFilter

Bases: Filter

Combines multiple filters into one single filter.

Source code in src/optimization/filters/filter.py
class CombinedFilter(Filter):
    """
    Chains several filters so that they behave as one filter.

    The filters run in the order they were given; each one only sees the
    states that the previous one let through.
    """
    def __init__(self, filters: list[Filter]):
        self.filters = filters

    def filter(self, states: list[State]) -> list[State]:
        """Pass ``states`` through every filter in turn and return what is left."""
        survivors = states
        for stage in self.filters:
            # Feed the output of one stage into the next.
            survivors = stage.filter(survivors)
        return survivors

MustContainPhraseFilter

Bases: Filter

Filters out states that do not contain certain phrases.

Source code in src/optimization/filters/filter.py
class MustContainPhraseFilter(Filter):
    """
    Filters out states that do not contain certain phrases.

    A state passes only when its prompt contains every phrase in ``words``
    (checked with the ``in`` operator). When ``words`` is empty there is
    nothing to require, so every state passes.
    """
    def __init__(self, words: list[str]):
        self.words = words

    def filter(self, states: list[State]) -> list[State]:
        """Return the states whose prompt contains all of the required phrases."""
        # all() checks every phrase for each state. The earlier hand-written
        # loop reset its flag on each iteration, so only the last phrase
        # counted, and an empty phrase list left the flag unbound.
        return [
            state for state in states
            if all(phrase in state.prompt for phrase in self.words)
        ]

MustNotContainPhraseFilter

Bases: Filter

Filters out states that contain certain phrases.

Source code in src/optimization/filters/filter.py
class MustNotContainPhraseFilter(Filter):
    """
    Rejects every state whose prompt includes at least one forbidden phrase.
    """
    def __init__(self, words: list[str]):
        self.words = words

    def filter(self, states: list[State]) -> list[State]:
        """Return the states whose prompt contains none of the forbidden phrases."""
        kept = []
        for candidate in states:
            # Collect every forbidden phrase that occurs in this prompt.
            offending = [
                phrase for phrase in self.words if phrase in candidate.prompt
            ]
            if offending:
                continue
            kept.append(candidate)
        return kept

NoSpecialTokensFilter

Bases: Filter

Filters out states that contain special tokens. A tokenizer must be passed in.

Source code in src/optimization/filters/filter.py
class NoSpecialTokensFilter(Filter):
    """
    Filters out states that contain special tokens. A tokenizer must be passed in.

    The special tokens are read from ``tokenizer.all_special_tokens_extended``.
    A state passes only when none of those tokens occurs in its prompt.
    """
    def __init__(self, tokenizer: AutoTokenizer):
        self.tokenizer = tokenizer

    def filter(self, states: list[State]) -> list[State]:
        """Return the states whose prompt contains no special token text."""
        # Resolve the token strings once per call instead of once per state.
        # Entries are either plain strings or objects that carry their text
        # in a ``content`` attribute.
        special_tokens = [
            token if isinstance(token, str) else token.content
            for token in self.tokenizer.all_special_tokens_extended
        ]
        return [
            state for state in states
            if not any(token in state.prompt for token in special_tokens)
        ]

OnlyPrintableCharactersFilter

Bases: Filter

Filters out states that have non printable unicode characters.

Source code in src/optimization/filters/filter.py
class OnlyPrintableCharactersFilter(Filter):
    """
    Keeps only states whose prompt consists entirely of accepted characters.

    A character is accepted when its Unicode general category is ``Lu``,
    ``Ll`` or ``Lo``, or when it appears in ``string.printable``.
    """
    def filter(self, states: list[State]) -> list[State]:
        """Return the states in which every prompt character is accepted."""
        letter_categories = {'Lu', 'Ll', 'Lo'}

        def is_accepted(ch):
            # Check the Unicode letter categories first, then fall back to
            # the ASCII printable set.
            if unicodedata.category(ch) in letter_categories:
                return True
            return ch in string.printable

        return [
            state for state in states
            if all([is_accepted(ch) for ch in state.prompt])
        ]

RemoveDuplicatesFilter

Bases: Filter

Filters out states that are duplicate within the state list.

Source code in src/optimization/filters/filter.py
class RemoveDuplicatesFilter(Filter):
    """
    Filters out states that are duplicate within the state list.

    The first occurrence of each state is kept and the original order of the
    remaining states is preserved. States must be hashable.
    """
    def filter(self, states: list[State]) -> list[State]:
        """Return ``states`` with later duplicates removed, in first-seen order."""
        # dict keys are unique and keep insertion order, so the result is
        # deterministic; a plain set() would return an arbitrary order.
        return list(dict.fromkeys(states))

StartsWithPhraseFilter

Bases: Filter

Filters out states that do not start with specific phrases.

Source code in src/optimization/filters/filter.py
class StartsWithPhraseFilter(Filter):
    """
    Filters out states that do not start with specific phrases.

    A state passes only when its prompt starts with every phrase in
    ``words``. When ``words`` is empty there is nothing to require, so every
    state passes.
    """
    def __init__(self, words: list[str]):
        self.words = words

    def filter(self, states: list[State]) -> list[State]:
        """Return the states whose prompt starts with all of the given phrases."""
        # all() checks every phrase for each state. The earlier hand-written
        # loop reset its flag on each iteration, so only the last phrase
        # counted, and an empty phrase list left the flag unbound.
        return [
            state for state in states
            if all(state.prompt.startswith(phrase) for phrase in self.words)
        ]