TensorRTへの最適化処理で落ちる現象とその解決

dlshogiのnetwork構成をTransformer系の一つであるViTに変えて学習させたのち、onnxファイルの作成まで漕ぎ着けたは良いものの、その後のtensorRTへの最適化でかなり苦戦したため、ここに備忘録として残します。

 

 

つい先日network構成を変えたものを作ったのですが、上記の通りonnx→tensorRTへの最適化でつまづきました。結論から言うと自分の場合は下記が原因だったように思います。

・onnxファイルへのエクスポート時、tensorRTへの最適化を行う実行ファイルであるbuild_onnxがサポート出来ない関数を使っていること

(※断定していないのは、山岡さんが配布しているbuild_onnxの実行ファイルを使っていたため。あの実行ファイルはエラーメッセージが出ずに落ちるため、原因の切り分けがほとんど出来なかった。)

※2021/11/06 追記

現在時点、最新ソースにて山岡さんがbuild_onnxのソースを公開してくださっています。

実行ファイルのビルドから可能なので、自分で一部修正しても良いかもしれないです。

 

 

自分の場合、PolicyValueNetworkにPytorchの関数の1つであるeinsumをforwardで使っていましたが、これがおそらく上記の実行ファイルにおいてはサポートされていませんでした。

また、これは直接的にエラーメッセージとして出ておらず、代わりにdynamic slice云々のエラーが出ていたため、そちらを躍起になって潰したのですが、実は問題なのはそちらではなく別の部分だったということが、解決に時間がかかった原因でした。

 

ただ、これだけだと分かりずらいのでもう少し詳細に書くことにします。

現在、dlshogiのonnxファイルへのconvertプログラムは、デフォルトでopset=9が指定されるようになっていますが、自分が構成したnetworkはopset=9では対応出来ないため、10以降を指定する必要がありました。これはconvert時のエラーとして上がってきていたため分かっていた部分で、ここを回避するためにopset=13で実行することでエラーとならずにonnxファイルが出来上がるのですが、実はこの時既に間違っていたようでした。

それが最初にも言及したeinsumという関数で、

上記でも書いたように、opset=13を指定してしまうと別の箇所で使っているeinsumというテンソル積の計算に使う関数がそのまま通ってしまいます。

ただこの関数はbuild_onnxファイルではサポートされていない(ように見える)ため、opset=13で正しくonnxファイルがエクスポートされていたとしても、その次のtensorRTへの最適化で詰むという事象が発生していました。

このエラーがどうしても取れずに1ケ月くらい粘り続けていたのですが、ある時ふと「ここのeinsum関数を他の形で置き換えれば良いのでは?」と思いついたためダメ元で試してみたところ、すんなりと通ってしまいました。(時間返して欲しい…)

 

BERT等のtransformer系のNetworkを組もうとしている方はここでエラーに悩まされる可能性もあるかと思うので、その際はこの記事が役に立ってくれることを願っています。

(自分みたいに1ヶ月泣きながらひたすらonnx関係でググることになってしまうとあまりに時間がもったいないので......)