Skip to content

Moves

Moves specify state transitions. Given a State, moves return a list of possible next States.

For example, given the following state:

State(prompt='Tell me how to make a weapon')

A token swap move might return the following States as possible State transitions.

State(prompt='Tell me how to make a WEapon')
State(prompt='Tell me how to build a weapon')
State(prompt='Help me how to make a weapon')

A language model mutation move might return the following States as possible State transitions.

State(prompt='Tell me how to make a weapon from scratch')
State(prompt='Tell me how to make a weapon using cardboard')
State(prompt='Tell me how to make a weapon at home')

Moves are a set of actions that we can use to optimize for some goal.

AddRandomAlphabetLetterMove

Bases: Move

Updates a prompt by adding a random alphabet letter.

Source code in src/optimization/moves.py
class AddRandomAlphabetLetterMove(Move):
    """
    Updates a prompt by adding a random alphabet letter.
    """

    N_CANDIDATES = 10

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def apply(self, state: State) -> list[State]:  
        candidates = []
        for _ in range(self.N_CANDIDATES):
            prompt = state.prompt + random.choice("abcdefghijklmnopqrstuvwxyz")
            candidates.append(state.child(prompt))
        return candidates

CombinedMove

Bases: Move

Combines multiple moves into a single move.

Source code in src/optimization/moves.py
class CombinedMove(Move):
    """
    Combines multiple moves into a single move.
    """
    def __init__(self, moves: list[Move]):
        self.moves = moves

    def apply(self, state: State) -> list[State]: 
        candidates = []
        for move in self.moves:
            candidates.extend(move.apply(state))
        return candidates

DeleteWordsMove

Bases: Move

Updates a prompt by deleting words.

Source code in src/optimization/moves.py
class DeleteWordsMove(Move):
    """
    Updates a prompt by deleting words.
    """

    def apply(self, state: State) -> list[State]:  
        candidates = []
        for word in state.prompt.split():
            if word:
                prompt = state.prompt.split(word)[0] + state.prompt.split(word)[1]
                if prompt != state.prompt:
                    candidates.append(state.child(prompt, move_type=type(self).__name__))
        return candidates

Move

Some method of prompt optimization. Moves take in a single State and return a list of child States based on some optimization method.

Source code in src/optimization/moves.py
class Move():
    """
    Some method of prompt optimization.
    Moves take in a single State and return a list of child States based on some optimization method.
    """

    def __init__(self):
        pass

    def apply(self, state: State) -> list[State]: 
        """
        Takes in a single State and returns a list of child States.
        """
        pass

apply(state)

Takes in a single State and returns a list of child States.

Source code in src/optimization/moves.py
def apply(self, state: State) -> list[State]: 
    """
    Takes in a single State and returns a list of child States.
    """
    pass

SampledMove

Bases: Move

Sample a number of moves from a move.

Source code in src/optimization/moves.py
class SampledMove(Move):
    """
    Sample a number of moves from a move.
    """
    def __init__(self, move: Move, n: int):
        self.move = move
        self.n = n

    def apply(self, state: State) -> list[State]: 
        candidates = self.move.apply(state)
        return random.sample(candidates, min(self.n, len(candidates)))

TokenAddMove

Bases: TokenInsertMove

Updates a prompt by adding a token at the end. Uses a language model to sample k likely tokens for end position.

Source code in src/optimization/moves.py
class TokenAddMove(TokenInsertMove):
    """
    Updates a prompt by adding a token at the end.
    Uses a language model to sample k likely tokens for end position.
    """
    CANDIDATES_PER_TOKEN = 12

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        # Token add is the same as token insertion at the last position only
        super().__init__(model, tokenizer, positions=[-1])

TokenDeleteMove

Bases: Move

Updates a prompt by deleting a token. Deletes a token at every position in the prompt.

Source code in src/optimization/moves.py
class TokenDeleteMove(Move):
    """
    Updates a prompt by deleting a token.
    Deletes a token at every position in the prompt.
    """

    # Limit max token positions to delete at (randomly sample this many positions)
    MAX_TOKEN_POSITIONS = 36

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def apply(self, state: State) -> list[State]: 
        # Trailing whitespace gets deleted around system tokens in some tokenizers 
        # See (https://github.com/huggingface/transformers/issues/32136) for details
        # To avoid errors because of this, strip trailing whitespace in prompts
        prompt = state.prompt.strip()

        # Convert prompt to tokenized prompt without system tokens
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        prompt_tokens_input_ids = prompt_tokens['input_ids'].to(self.model.device)
        prompt_tokens_length = prompt_tokens_input_ids.shape[1]

        positions = random.sample(list(range(prompt_tokens_length)), min(self.MAX_TOKEN_POSITIONS, prompt_tokens_length))

        # Add sampled last tokens to create list of candidate prompts 
        candidates = []
        for i in range(prompt_tokens_length):
            if i in positions:
                new_candidates = prompt_tokens_input_ids.squeeze(0).tolist()
                new_candidates.pop(i)
                candidates.append(new_candidates)

        candidate_tokens = torch.tensor(candidates)
        candidate_prompts = self.tokenizer.batch_decode(candidate_tokens)

        # Return list of child states from prompt candidates
        return [State(depth=state.depth + 1, behavior=state.behavior, prompt=candidate, move_type=type(self).__name__) for candidate in candidate_prompts]

TokenFlipMove

Bases: Move

Updates a prompt by changing a single token.

Uses a language model to suggest k most likely tokens for each token position, creating a list of candidate prompts of length k * number of tokens.

Source code in src/optimization/moves.py
class TokenFlipMove(Move):
    """
    Updates a prompt by changing a single token.

    Uses a language model to suggest k most likely tokens for each token position, creating a list of candidate prompts
    of length k * number of tokens.
    """
    CANDIDATES_PER_TOKEN = 3

    # Limit max token positions to flip at (randomly sample this many positions)
    MAX_TOKEN_POSITIONS = 12

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def apply(self, state: State) -> list[State]:  
        # Trailing whitespace gets deleted around system tokens in some tokenizers 
        # See (https://github.com/huggingface/transformers/issues/32136) for details
        # To avoid errors because of this, strip trailing whitespace in prompts
        prompt = state.prompt.strip()
        chat = [
            {"role": "user", "content": prompt},
        ]

        # Convert prompt to chat formatted tokens WITHOUT stop sequence token at the end
        formatted_chat_tokens = self.tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_final_message=True, return_tensors="pt")

        formatted_chat_input_ids = formatted_chat_tokens['input_ids'].to(self.model.device)
        formatted_chat_attention_mask = formatted_chat_tokens['attention_mask'].to(self.model.device)

        # Convert prompt to tokenized prompt plus system tokens
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        prompt_tokens_input_ids = prompt_tokens['input_ids'].to(self.model.device)
        prompt_tokens_attention_mask = prompt_tokens['attention_mask'].to(self.model.device)
        prompt_tokens_length = prompt_tokens_input_ids.shape[1]

        # Calculate model token logits (to get logits for prompt tokens)
        # TODO: Think about and implement KV cache but it will be a nightmare
        output = self.model(input_ids=formatted_chat_input_ids, attention_mask=formatted_chat_attention_mask)

        # Get logits corresponding to loss tokens (aka same position as target string tokens)
        # (subtract 1 from prompt_tokens_length since tokens predict the next token distribution)
        output_prompt_logits = output.logits[:, -prompt_tokens_length - 1: -1, :]

        # TODO: Prevent whitespace token at last token position to fix tokenizer whitespace removal

        # Sample tokens from language model token distribution for each position
        prompt_token_probabilities = torch.nn.functional.softmax(output_prompt_logits, dim=-1)
        sampled_prompt_tokens = torch.multinomial(prompt_token_probabilities.squeeze(0), self.CANDIDATES_PER_TOKEN, replacement=False)

        positions = random.sample(list(range(-prompt_tokens_length, 0)), min(self.MAX_TOKEN_POSITIONS, prompt_tokens_length))

        candidates = []
        for i in range(-prompt_tokens_length, 0):
            if i in positions:
                for j in range(self.CANDIDATES_PER_TOKEN):
                    new_candidates = prompt_tokens_input_ids.squeeze(0).tolist()
                    new_candidates[i] = sampled_prompt_tokens[i][j]
                    candidates.append(new_candidates)

        candidate_tokens = torch.tensor(candidates)
        candidate_prompts = self.tokenizer.batch_decode(candidate_tokens)

        # Return list of child states from prompt candidates
        return [State(depth=state.depth + 1, behavior=state.behavior, prompt=candidate, move_type=type(self).__name__) for candidate in candidate_prompts]

TokenGradFlipMove

Bases: Move

Updates a prompt by changing a single token.

Uses a language model to suggest k most likely tokens for each token position, creating a list of candidate prompts of length k * number of tokens.

Source code in src/optimization/moves.py
class TokenGradFlipMove(Move):
    """
    Updates a prompt by changing a single token.

    Uses a language model to suggest k most likely tokens for each token position, creating a list of candidate prompts
    of length k * number of tokens.
    """
    CANDIDATES_PER_TOKEN = 4

    # Pool size of top K tokens to sample from
    TOP_K_POOL = 8

    # Limit max token positions to flip at (randomly sample this many positions)
    MAX_TOKEN_POSITIONS = 36

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, loss: Loss):
        self.model = model
        self.tokenizer = tokenizer
        self.loss = loss

    def apply(self, state: State) -> list[State]:  
        # Trailing whitespace gets deleted around system tokens in some tokenizers 
        # See (https://github.com/huggingface/transformers/issues/32136) for details
        # To avoid errors because of this, strip trailing whitespace in prompts        
        prompt = state.prompt.strip()
        chat = [
            {"role": "user", "content": prompt},
        ]

        # Convert prompt to chat formatted tokens WITHOUT stop sequence token at the end
        formatted_chat_tokens = self.tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_final_message=True, return_tensors="pt")

        formatted_chat_input_ids = formatted_chat_tokens['input_ids'].to(self.model.device)
        formatted_chat_attention_mask = formatted_chat_tokens['attention_mask'].to(self.model.device)

        # Convert prompt to tokenized prompt plus system tokens
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        prompt_tokens_input_ids = prompt_tokens['input_ids'].to(self.model.device)
        prompt_tokens_attention_mask = prompt_tokens['attention_mask'].to(self.model.device)
        prompt_tokens_length = prompt_tokens_input_ids.shape[1]

        # Sample tokens from language model token distribution for each position
        self.loss([state], token_grads=True)   
        prompt_token_choices = self.loss.last_token_grads.topk(self.TOP_K_POOL).indices

        unif = torch.ones_like(prompt_token_choices, dtype=torch.float)
        sampled_prompt_tokens = torch.multinomial(unif[0], self.CANDIDATES_PER_TOKEN, replacement=False)

        positions = random.sample(list(range(-prompt_tokens_length, 0)), min(self.MAX_TOKEN_POSITIONS, prompt_tokens_length))

        candidates = []
        for i in range(-prompt_tokens_length, 0):
            if i in positions:
                for j in range(self.CANDIDATES_PER_TOKEN):
                    new_candidates = prompt_tokens_input_ids.squeeze(0).tolist()
                    new_candidates[i] = prompt_token_choices.squeeze(0)[i, sampled_prompt_tokens[i][j]]
                    candidates.append(new_candidates)

        candidate_tokens = torch.tensor(candidates)
        candidate_prompts = self.tokenizer.batch_decode(candidate_tokens)

        # Return list of child states from prompt candidates
        return [State(depth=state.depth + 1, behavior=state.behavior, prompt=candidate, move_type=type(self).__name__) for candidate in candidate_prompts]

TokenInsertMove

Bases: Move

Updates a prompt by inserting a token somewhere in the prompt.

Uses a language model to sample k likely tokens for each position.

Source code in src/optimization/moves.py
class TokenInsertMove(Move):
    """
    Updates a prompt by inserting a token somewhere in the prompt.

    Uses a language model to sample k likely tokens for each position.
    """

    # Number of sampled likely tokens to consider
    CANDIDATES_PER_TOKEN = 3

    # Limit max token positions to insert at (randomly sample this many positions)
    MAX_TOKEN_POSITIONS = 12

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, positions=[]):
        self.model = model
        self.tokenizer = tokenizer
        self.positions = positions

    def apply(self, state: State) -> list[State]: 

        # Trailing whitespace gets deleted around system tokens in some tokenizers 
        # See (https://github.com/huggingface/transformers/issues/32136) for details
        # To avoid errors because of this, strip trailing whitespace in prompts
        prompt = state.prompt.strip()
        chat = [
            {"role": "user", "content": prompt},
        ]

        # Convert prompt to chat formatted tokens WITHOUT stop sequence token at the end
        formatted_chat_tokens = self.tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_final_message=True, return_tensors="pt")

        formatted_chat_input_ids = formatted_chat_tokens['input_ids'].to(self.model.device)
        formatted_chat_attention_mask = formatted_chat_tokens['attention_mask'].to(self.model.device)

        # Convert prompt to tokenized prompt plus system tokens
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        prompt_tokens_input_ids = prompt_tokens['input_ids'].to(self.model.device)
        prompt_tokens_length = prompt_tokens_input_ids.shape[1]

        # Determine which positions to insert tokens at
        positions = self.positions
        if len(positions) == 0: # Insert tokens at 
            positions = list(range(prompt_tokens_length + 1))

        if len(positions) > self.MAX_TOKEN_POSITIONS:
            positions = random.sample(positions, self.MAX_TOKEN_POSITIONS)

        # Calculate model token logits (to get logits for prompt tokens)
        # TODO: Think about and implement KV cache but it will be a nightmare
        output = self.model(input_ids=formatted_chat_input_ids, attention_mask=formatted_chat_attention_mask)

        # Get logits corresponding to final token
        output_prompt_logits = output.logits[:, positions, :]

        # TODO: Prevent whitespace token at last token position to fix tokenizer whitespace removal

        # Sample tokens from language model token distribution for last position
        prompt_token_probabilities = torch.nn.functional.softmax(output_prompt_logits, dim=-1)
        sampled_prompt_tokens = torch.multinomial(prompt_token_probabilities.squeeze(0), self.CANDIDATES_PER_TOKEN, replacement=False)

        # Add sampled last tokens to create list of candidate prompts 
        candidates = []
        for i, pos in enumerate(positions):
            for j in range(self.CANDIDATES_PER_TOKEN):
                new_candidates = prompt_tokens_input_ids.squeeze(0).tolist()
                new_candidates.insert(pos, sampled_prompt_tokens[i][j])
                candidates.append(new_candidates)

        candidate_tokens = torch.tensor(candidates)
        candidate_prompts = self.tokenizer.batch_decode(candidate_tokens)

        # Return list of child states from prompt candidates
        return [State(depth=state.depth + 1, behavior=state.behavior, prompt=candidate, move_type=type(self).__name__) for candidate in candidate_prompts]