Calling sympy.stats.sample from SymPy.jl

Hi Julia discourse – this forum (and Julia) seems so great I thought I would stop lurking.

I’ve been looking at how to interact with sympy.stats using SymPy.jl. The author of the package was kind enough to give some initial tips here.

This works great, with the exception of the sample function, which returns the output below. I’m not grokking this error: how would I sample from this random variable?

using SymPy, PyCall;
SymPy.PyCall.pyimport_conda("sympy.stats", "sympy")
@vars Z

Z = sympy.stats.Normal("Z", 0, 2)
sympy.stats.E(Z) # returns 0

sympy.stats.sample(Z) # returns the below


/Users/toomey8/.julia/conda/3/lib/python3.8/site-packages/sympy/stats/rv.py:1104: UserWarning:
The return type of sample has been changed to return an iterator
object since version 1.7. For more information see
https://github.com/sympy/sympy/issues/19061
  warnings.warn(filldedent(message))
PyObject <generator object sample_iter.<locals>.return_generator_finite at 0x7fb8b1380f20>

How would I get it to return a sample from the random variable?

Looking at the linked issue, it appears you need to treat the returned object as an iterator:

julia> py"next"(sympy.stats.sample(Z))
-1.7527355361953159

julia> py"next"(sympy.stats.sample(Z, size=3))
3-element Vector{Float64}:
 -1.0020213130043105
  1.1586894921503135
  0.3896799828295135

Of course, if all you need to do is sample from a normal distribution, you can use Distributions.jl (which should be much faster than calling Python) or Julia’s built-in randn, scaled by the standard deviation:

julia> using Distributions, BenchmarkTools

julia> @btime py"next"(sympy.stats.sample($Z))
  442.900 μs (15 allocations: 624 bytes)
-0.427252225353338

julia> Zⱼ = Normal(0, 2.0)
Normal{Float64}(μ=0.0, σ=2.0)

julia> @btime rand($Zⱼ)
  8.400 ns (0 allocations: 0 bytes)
0.6076344626588824

julia> @btime 2randn()
  7.900 ns (0 allocations: 0 bytes)
1.4214944284831204
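
For anyone wondering why 2randn() matches Normal(0, 2): a standard normal draw can be shifted by the mean and scaled by the standard deviation. Here is a minimal sketch of that idea using only the standard library. The helper name sample_normal and the seed are my own choices for illustration, not something from SymPy.jl or Distributions.jl.

```julia
using Random, Statistics

# Hypothetical helper: draw n samples from Normal(μ, σ) by shifting and
# scaling standard normal draws from randn.
sample_normal(rng, μ, σ, n) = μ .+ σ .* randn(rng, n)

rng = MersenneTwister(42)                    # fixed seed, chosen arbitrarily
xs = sample_normal(rng, 0.0, 2.0, 100_000)

println("sample mean: ", mean(xs))           # should land close to 0
println("sample std:  ", std(xs))            # should land close to 2
```

With this many draws, the sample mean and standard deviation should sit near the parameters of Z above, which is a quick sanity check that the two approaches describe the same distribution.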

Thank you. This is absolutely perfect. Good to know about doing something similar with the Distributions package too. Appreciated on all counts!

Thanks for pointing this out. A few things for future readers:

  • I was also getting a message about failing to import scipy, so I added it with SymPy.PyCall.pyimport_conda("scipy", "scipy"). (Not sure this was needed.)
  • In place of py"next"(...) you can call collect on the iterator, as in: collect(sympy.stats.sample(Z, numsamples=3)). (This is used elsewhere with Python sets, so hopefully it is more idiomatic.)
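
To make the difference between py"next"(...) and collect concrete, here is a small sketch that does not depend on the installed SymPy version. It assumes PyCall is installed, and three_values is a hypothetical stand-in I made up for the generator that sympy.stats.sample returned above.

```julia
using PyCall

# Define a tiny Python generator. `three_values` is a hypothetical stand-in
# for the generator object returned by sympy.stats.sample in this thread.
py"""
def three_values():
    yield 1.0
    yield 2.0
    yield 3.0
"""

gen = py"three_values"()      # a PyObject wrapping a Python generator

first_value = py"next"(gen)   # pulls a single value from the generator
remaining = collect(gen)      # drains whatever the generator has left

println(first_value)
println(remaining)
```

The point to notice is that a generator is consumed as you go: after one call to next, collect only sees the values that are left. So py"next" suits a single draw, while collect suits grabbing everything the iterator will produce.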