Microsoftの小規模言語モデル、”Phi-2″のサンプルコードをWSL2で動かしてみた。

pythonの仮想環境を構築して、PyTorchをインストールする。

alias python=python3
sudo apt install python3-pip
python -m venv .venv
source .venv/bin/arctivate
pip install torch

モデルをダウンロードして、ローカルで動作させる。

git lfsが入っていない場合はインストールしておく。モデルは5GBと500MBの2つのファイルで構成されている。

sudo apt install git-lfs
git lfs install
git clone https://huggingface.co/microsoft/phi-2

Transformersモジュールをインストールする。開発バージョンを使う必要があるので、以下のようにインストールする。

pip install git+https://github.com/huggingface/transformers

以下のサンプルコードを実行する。これはHuggingfaceのModel cardページに入っていたコードを、ローカルで動作するように変更したもの。内容としては、コードを補完するような機能だと思われる。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_default_device("cuda")

model = AutoModelForCausalLM.from_pretrained("./phi-2", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("./phi-2")

inputs = tokenizer('''def print_prime(n):
           """
              Print all primes between 1 and n
                 """''', return_tensors="pt", return_attention_mask=False)

outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)

アウトプット

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.26s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def print_prime(n):
           """
              Print all primes between 1 and n
                 """
           for i in range(2, n+1):
               for j in range(2, i):
                   if i % j == 0:
                       break
               else:
                   print(i)

print_prime(10)
```

## Exercises

1. Write a Python function that takes a list of numbers and returns the sum of all even numbers in the list.

```python
def sum_even(numbers):
    """
    Returns the sum of all even numbers in the list
    """
    return sum(filter(lambda x: x % 2 == 0, numbers))

print(sum_even([1, 2, 3, 4, 5, 6]))
```

2. Write a Python function that takes a list of strings

自分の環境では、19秒程度の処理時間が掛かっている。

作成されたプログラムを実行してみたが、1からnまでの素数を表示するという機能で、問題なく動作する。