Parallelized simple optimizations in modern PyTorch or TensorFlow

I have a task where I have to quickly run N parallel simple optimizations with Adam optimizers. I have been doing this with TensorFlow 1.x for a while, but updating everything to 2.x (or alternatively to modern PyTorch) has resulted in much slower behavior.
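To make the task concrete, the batched setup looks roughly like this: the N independent problems are stacked into one tensor so a single Adam optimizer advances all N optimizations in one step. This is only an illustrative sketch (the sizes and the toy quadratic loss are placeholders, not my actual problem):

```python
import torch

torch.manual_seed(0)

# N independent 1-D problems stacked into a single tensor, so one
# Adam optimizer updates all N of them in every step.
# Sizes and the quadratic loss are illustrative, not my real task.
N = 64
target = torch.randn(N)                  # one target per problem
x = torch.zeros(N, requires_grad=True)   # N parameters, one per problem
opt = torch.optim.Adam([x], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    loss = ((x - target) ** 2).sum()     # summed loss over all N problems
    loss.backward()
    opt.step()
```

Because the problems are independent, the summed loss gives each parameter its own gradient, so this really is N optimizations running in lockstep.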

I constructed minimal (at least to the best of my ability) examples demonstrating my problem in TensorFlow 1.x, TensorFlow 2.x, and PyTorch, all on CPU.
https://gist.github.com/gftabor/abeb108fc9aa8b1c799bfc63287c2e5f

As you can see, TensorFlow 2.x takes roughly 8 times longer than TensorFlow 1.x. PyTorch is similar to TF 2.x.

I assume the problem is dynamic graph execution, so that is the rabbit hole I have dug into the most, but ultimately all I care about is performance.

Rabbit hole

I have been trying to match TensorFlow 1.x performance by statically building graphs using the tf.function and torch.jit.script functionality. Neither gives me nearly the expressive power of TensorFlow 1.x, and neither matches its performance either: on my problem, the parts I can build into a static graph are just the forward + loss functions, not a full Adam optimization step. With TensorFlow 1.x, by contrast, I can build N of these full Adam optimization steps in parallel as a single static graph and pass data into it quickly.
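For reference, the furthest I got on the PyTorch side looks roughly like this (an illustrative sketch with a toy quadratic loss, not my actual model): only the forward + loss function is scripted, while the backward pass and Adam update still run eagerly on every iteration:

```python
import torch

# Only the forward + loss is compiled with torch.jit.script; the
# gradient computation and Adam update still run eagerly each step.
# The quadratic loss and sizes here are illustrative placeholders.
@torch.jit.script
def forward_loss(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return ((x - target) ** 2).sum()

torch.manual_seed(0)
target = torch.randn(64)
x = torch.zeros(64, requires_grad=True)
opt = torch.optim.Adam([x], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    loss = forward_loss(x, target)   # the only static-graph part
    loss.backward()                  # eager
    opt.step()                       # eager
```

The per-iteration Python overhead of the eager backward + optimizer step is exactly the part that TF 1.x folded into one static graph.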

Claim
It doesn’t seem possible to get even half the speed of TensorFlow 1.x with either modern TensorFlow or PyTorch, which makes upgrading very unattractive.

Is there a method to get a large speedup in either TensorFlow 2.x or PyTorch that I am missing?

Source: Python Questions
