Gaussian process regression with Turing gets stuck

Hello,

I have an issue where sampling from a Turing model with a Gaussian process sometimes gets stuck, printing the following warning message over and over:

┌ Warning: [DynamicPPL] attempt to link a linked vi
└ @ DynamicPPL C:\Users\sheld\.julia\packages\DynamicPPL\1qg3U\src\varinfo.jl:821

It happens only with some datasets, and quite rarely, though it seems to happen more often with more data.

Below is an MWE with one dataset that causes the problem and one that does not.

using Turing
using AbstractGPs
using Random

# MWE script - - - - - - - - - - - - - - - - - - - - - - - -

function example(X, y)
    noise_prior = LogNormal(-2.3, 0.5)
    gp_lengthscale_prior = Gamma(2., 1.)

    model = turing_model(X, y, gp_lengthscale_prior, noise_prior)
    return turing_sample(model)
end

function turing_sample(model)
    rng = Random.MersenneTwister(5555)
    
    warmup = 60
    samples = 200
    
    return Turing.sample(rng, model, NUTS(warmup, 0.65), samples)
end

# Turing model - - - - - - - - - - - - - - - - - - - - - - - -

Turing.@model function turing_model(X, y, gp_lengthscale_prior, noise_prior)
    noise ~ noise_prior
    gp_lengthscale ~ gp_lengthscale_prior
    gp = construct_finite_gp(X, gp_lengthscale, noise)

    y ~ gp
end

# Gaussian process - - - - - - - - - - - - - - - - - - - - - - - -

function construct_finite_gp(X, lengthscale, noise; min_param_val=1e-6)
    (lengthscale < 0.) && throw(ArgumentError("Lengthscale cannot be negative. Got '$lengthscale'."))
    (noise < 0.) && throw(ArgumentError("Noise cannot be negative. Got '$noise'."))

    # for numerical stability
    lengthscale = lengthscale + min_param_val
    noise = noise + min_param_val

    kernel = with_lengthscale(Matern52Kernel(), lengthscale)
    return GP(mean, kernel)(X', noise)
end

# SCRIPT - - - - - - - - - - - - - - - - - - - - - - - -

Xe = [5.977282456760649; 14.024008416431135; 4.714920089507371; 1.4887056207723681; 4.164617393121821; 15.38209256655557; 3.4794920708690125; 6.97052251854235; 11.522599454384048; 4.6973097649736495; 12.904744993787478; 6.260220936565966; 11.572088574947287; 19.1742779435084; 18.270049695303978; 12.480627265866438; 13.856151294082423; 3.1616321947358816; 13.551785876938105; 11.344106987962093; 10.197310386853143; 1.0547597514517504; 0.5378063316883486; 11.234866993420551; 3.4090650446542803; 6.895265328416009; 9.839096790020632; 5.0496933550255685; 0.9200444542036923; 8.36686271435218; 2.919167098841391; 14.15329887009738; 9.983285280140164; 15.253897683141533; 0.578700848418876; 16.989723268811513; 1.4348060863330536; 4.939139720002688; 14.859911590553686; 12.0266196532552; 1.2267224692520262; 7.170571793898142; 11.214481494307247; 4.681995824477632; 13.407139593429083; 2.9521041424754113; 3.130005456636833; 13.53938712126113; 14.49631305195389; 8.410528209868854; 13.60192495564414; 15.169407147695004; 16.621607883141337; 9.460124793983562; 13.483089934769215; 7.4846641933872915; 5.891111058719201; 5.110014600657191; 14.259730066691523; 15.297242020959443; 1.4319839928612677; 4.359314151026461; 19.130138655345004; 8.141685316629793; 3.995629574197479; 19.68167432816219; 11.991906765162037; 1.4732757708264366; 6.942640229922428; 12.660477193972755; 10.242487513691096; 0.2075168520435744; 1.6000715304875168; 1.676227312006997; 16.427298547560063; 0.6341049307769131; 7.121957965252146; 4.670782950643311; 12.554032489961259; 12.527034553363121; 15.161985333327337; 18.724609238891006; 14.97244714266218; 3.332215149880038; 11.926054512795492; 14.756337439493748; 10.384140307975294; 5.152244483546253; 0.39798970022526525; 14.702252846106084; 10.53731600825731; 2.6817721455955823; 9.702298632708676; 8.794328206261245; 3.7301315309398575; 16.698319646692457; 17.560508679315635; 6.914393364419982; 11.004818883339777; 1.5647473861180794; 16.08263612726779; 2.222615434630635; 
11.270421708692488; 18.13686478979323; 4.621445194392768; 11.564059334093276; 2.1970436663829895; 14.712695498682148; 6.189475520120018; 11.468424073622622; 0.6731765173017301; 2.4912564503185886; 8.617555553966172; 4.3231706234782274; 9.792621603739935; 18.889619552255862; 19.710423242778518; 12.421121266381142; 8.724437066372644; 2.3181411848298827; 14.230581483981645; 6.6254713585793334; 12.992127656242296; 13.172670124751829; 18.95389921827848; 19.374523980212185; 18.687026497643664; 0.6310197611379165; 18.21364309610821; 4.123670155596506; 15.54348809708339; 7.367341970002146; 14.818535865958893; 3.57389229680654; 12.205478956101437; 4.579939198082226; 18.461402557589174; 2.5109963343299713; 2.3507750089273705; 17.41837194203383; 9.141908890922451; 13.596503123980245; 1.9477149645404834; 13.839379889788193; 10.207540876717163; 14.506081784453237; 9.592740240928004; 3.3265153433964434; 13.875443695208304; 18.06684599027986; 12.750180888477285; 8.859692649841634; 0.23200360659288766; 17.661990033225162; 6.90371181316725; 12.38499405677485; 3.3488713458752173; 12.717784149718536; 1.7072512646581162; 1.3055332401259134; 18.56850260931642; 17.253729684534775; 6.534998507033407; 18.116900108218314; 3.080784475583749; 1.8289978156107178; 14.119071687677373; 16.704638358245457; 3.0989846397998067; 12.701742529653707; 10.799368636070534; 4.153784891501933; 2.3771482765803764; 16.452843129173026; 1.417178371472021; 16.75166581769424; 8.765098367589355; 1.0120266383560672; 6.874714456102962; 15.350130967863556; 10.411857742123356; 0.5398075571756067; 7.7808185117894935; 8.788690107900901; 18.43255295291015; 13.549669503506667; 8.232672373330253; 6.198244001328579; 12.556386840490642; 1.488598684248399; 5.251209122505434; 13.215566934620199; 11.361354473897695; 12.699730763717225; 13.46305736062179; 17.872915307754194; 10.409061210184234; 1.7797242692729265; 1.866541365184995; 6.096462742679294; 12.440563096656291; 9.333333333333334; 12.83893199495071; 12.326893363981357; 
9.333333333333334;;];
ye = [1.63607486668512, -3.9299765757042815, -1.6677557931938989, -1.1695162796669427, -0.6182879715016405, 3.8527100027287915, 1.1421167645732273, 0.1290814898016217, -1.6638670673236804, -1.4587297145669933, 2.985685200516301, 1.7975445756084272, -1.159130157437394, 5.524655355016562, 2.546549022900165, 3.648427839106261, -3.5272527523090407, 1.339381840129779, -1.4749196627331584, -2.3747387566734894, 0.12280569900293328, -0.6086331401265381, 0.5693300491898818, -2.6873196763693965, 1.166999740947411, 0.7349911462462724, 1.7916752284937227, -1.271410408433067, -0.29284848678939013, -1.1595606836775407, 1.2561661340604586, -3.8005998752253944, 1.220553387059398, 2.8261188344846855, 0.45879888082542475, -4.606677250067635, -1.0113789908398705, -1.4835650966264822, -0.3516093603025691, 1.6750568605458662, -0.8383230988754284, -0.6538360481994173, -2.7959286629506424, -1.748612740670036, -0.38785471915791414, 1.172805138036895, 1.4144557665888697, -1.456101703268414, -3.3720889237050127, -1.0235561499753565, -1.7088239083337804, 2.0651714488590613, -1.4094536951342522, 2.5008080211451293, -0.9027510688444842, -1.5253676722615228, 1.3185465282508684, -1.0262227986538077, -3.932356673326095, 3.2164573813040227, -1.0352891267949393, -1.2613177819645802, 5.75560531988228, -2.0310012570218694, -0.06601327291342352, -0.7196098764059458, 1.2530729995316494, -1.1974858994843252, 0.22860099899517733, 3.4477662345926032, -0.12806688379101444, 0.8958347277441676, -1.1740155942683626, -1.1156203347228577, 0.7661852004903004, 0.1860293250604337, -0.1936413424573923, -1.5000961387134035, 3.714824189030127, 3.420723585660592, 2.2466541959713853, 6.467609293736399, 0.45497631496262714, 1.2238516739837557, 0.8259040552436111, -1.3619809695031146, -0.9193173621085828, -0.8804819198524767, 0.6419208900488049, -1.9015755153121048, -1.653188050504348, 0.8667616120457065, 2.3639545236283848, 0.714312071458105, 0.686428749103459, -2.166704336861362, -4.9600553635785225, 
0.5011397683458038, -3.048955098918502, -1.1558912459987978, 3.5636073233447805, -0.37372839709190964, -2.8265732595604898, 0.7686885461394604, -1.6003518928814042, -1.3306706009606002, -0.5133346232126792, -1.9064693495359677, 1.8278461403052533, -1.7695798765430624, 0.07187151268700795, 0.3725864285064703, 0.11264500263178602, -0.9167303102103824, 2.069201570734917, 6.4351624934508385, -1.0118930558517851, 3.1075492468250197, 0.4948913728389934, -0.2398915020977868, -4.219350130239142, 1.2572428213051308, 2.3635280035325095, 1.3859569503168512, 6.38229999507575, 3.600532448272419, 6.2485378536436595, 0.3616098276451784, 1.9131108827157077, -0.7655580444681584, 4.527827555830796, -1.1126657013753605, -0.8970584867449667, 0.8866643234421595, 2.5862644960553562, -1.5503272830948935, 4.638995363689807, 0.4472560053597254, -0.11347675169233301, -5.460131227503444, 1.9722533370596855, -1.7441599515911423, -0.9369525670299492, -3.4447153227435185, -0.10582241421944669, -3.120263054374272, 2.3713296161855735, 1.4005277354489756, -3.3642910263906876, -0.04522549504459879, 3.2534785424780623, 1.1209248675403674, 0.9095816757561165, -4.180021617989406, 0.5712451344097331, 3.241302519088113, 1.3131391508768244, 3.3889196334601803, -1.0501667709014715, -1.046114971839579, 5.547563122706457, -5.597493493146314, 1.683667658830885, 0.6514608087897962, 1.370895573149018, -1.1395665010795164, -4.224205424716516, -2.034289718948435, 1.3110582899600574, 3.463960260491116, -2.7829784834433, -0.6289514932185051, 0.21978004657067796, 0.4741427617311176, -1.2102111677489318, -2.594091668840032, 0.5029404836945448, -0.5015883339079734, 0.8280991437026622, 3.727479085083741, -1.0945801319578732, 0.6642794957575822, -2.122141138219406, 0.7804292570698632, 4.315666972700411, -1.4540064800515446, -1.7149210875972902, 1.90083049708229, 3.573671536661788, -0.9132472989165579, -0.6042219766719706, 1.0639468593544579, -2.2860655001892733, 3.417726862433142, -0.8599721859734553, 
-1.9476066310532838, -1.142008403994458, -1.068693326599653, -0.8900064836587227, 1.7117191541180385, 3.3616152068620084, 2.5147703756116058, 3.0061354115478314, 3.2779414887119485, 2.4671336794377408];

Xg = [5.977282456760649; 14.024008416431135; 4.714920089507371; 1.4887056207723681; 4.164617393121821; 15.38209256655557; 3.4794920708690125; 6.97052251854235; 11.522599454384048; 4.6973097649736495; 12.904744993787478; 6.260220936565966; 11.572088574947287; 19.1742779435084; 18.270049695303978; 12.480627265866438; 13.856151294082423; 3.1616321947358816; 13.551785876938105; 11.344106987962093; 10.197310386853143; 1.0547597514517504; 0.5378063316883486; 11.234866993420551; 3.4090650446542803; 6.895265328416009; 9.839096790020632; 5.0496933550255685; 0.9200444542036923; 8.36686271435218; 2.919167098841391; 14.15329887009738; 9.983285280140164; 15.253897683141533; 0.578700848418876; 16.989723268811513; 1.4348060863330536; 4.939139720002688; 14.859911590553686; 12.0266196532552; 1.2267224692520262; 7.170571793898142; 11.214481494307247; 4.681995824477632; 13.407139593429083; 2.9521041424754113; 3.130005456636833; 13.53938712126113; 14.49631305195389; 8.410528209868854; 13.60192495564414; 15.169407147695004; 16.621607883141337; 9.460124793983562; 13.483089934769215; 7.4846641933872915; 5.891111058719201; 5.110014600657191; 14.259730066691523; 15.297242020959443; 1.4319839928612677; 4.359314151026461; 19.130138655345004; 8.141685316629793; 3.995629574197479; 19.68167432816219; 11.991906765162037; 1.4732757708264366; 6.942640229922428; 12.660477193972755; 10.242487513691096; 0.2075168520435744; 1.6000715304875168; 1.676227312006997; 16.427298547560063; 0.6341049307769131; 7.121957965252146; 4.670782950643311; 12.554032489961259; 12.527034553363121; 15.161985333327337; 18.724609238891006; 14.97244714266218; 3.332215149880038; 11.926054512795492; 14.756337439493748; 10.384140307975294; 5.152244483546253; 0.39798970022526525; 14.702252846106084; 10.53731600825731; 2.6817721455955823; 9.702298632708676; 8.794328206261245; 3.7301315309398575; 16.698319646692457; 17.560508679315635; 6.914393364419982; 11.004818883339777; 1.5647473861180794; 16.08263612726779; 2.222615434630635; 
11.270421708692488; 18.13686478979323; 4.621445194392768; 11.564059334093276; 2.1970436663829895; 14.712695498682148; 6.189475520120018; 11.468424073622622; 0.6731765173017301; 2.4912564503185886; 8.617555553966172; 4.3231706234782274; 9.792621603739935; 18.889619552255862; 19.710423242778518; 12.421121266381142; 8.724437066372644; 2.3181411848298827; 14.230581483981645; 6.6254713585793334; 12.992127656242296; 13.172670124751829; 18.95389921827848; 19.374523980212185; 18.687026497643664; 0.6310197611379165; 18.21364309610821; 4.123670155596506; 15.54348809708339; 7.367341970002146; 14.818535865958893; 3.57389229680654; 12.205478956101437; 4.579939198082226; 18.461402557589174; 2.5109963343299713; 2.3507750089273705; 17.41837194203383; 9.141908890922451; 13.596503123980245; 1.9477149645404834; 13.839379889788193; 10.207540876717163; 14.506081784453237; 9.592740240928004; 3.3265153433964434; 13.875443695208304; 18.06684599027986; 12.750180888477285; 8.859692649841634; 0.23200360659288766; 17.661990033225162; 6.90371181316725; 12.38499405677485; 3.3488713458752173; 12.717784149718536; 1.7072512646581162; 1.3055332401259134; 18.56850260931642; 17.253729684534775; 6.534998507033407; 18.116900108218314; 3.080784475583749; 1.8289978156107178; 14.119071687677373; 16.704638358245457; 3.0989846397998067; 12.701742529653707; 10.799368636070534; 4.153784891501933; 2.3771482765803764; 16.452843129173026; 1.417178371472021; 16.75166581769424; 8.765098367589355; 1.0120266383560672; 6.874714456102962; 15.350130967863556; 10.411857742123356; 0.5398075571756067; 7.7808185117894935; 8.788690107900901; 18.43255295291015; 13.549669503506667; 8.232672373330253; 6.198244001328579; 12.556386840490642; 1.488598684248399; 5.251209122505434; 13.215566934620199; 11.361354473897695; 12.699730763717225; 13.46305736062179; 17.872915307754194; 10.409061210184234; 1.7797242692729265; 1.866541365184995; 6.096462742679294; 17.96890585653628; 2.9628281183676908; 11.851504138178953; 4.458312645991227; 
6.731987624734375;;];
yg = [1.6390858325387445, -3.9240711880010912, -1.86454018280297, -1.2449368117484771, -0.5538277468143176, 3.854021755944533, 1.0342591912685817, 0.5227450901332067, -1.4575849969343706, -1.539905986619357, 3.049448482380489, 1.7233308105825047, -1.3219868550929117, 5.453547417190307, 2.4969027483612187, 3.483965399250513, -3.420921657395973, 1.438633363162867, -1.4673816724779372, -2.4267912502515756, 0.1293059742621729, -0.5863439351496657, 0.524024202179069, -2.731171386831417, 1.2471486706800057, 0.7245088323731348, 2.1230101480607475, -1.2630487442707288, -0.29369676606803335, -1.1637961569810065, 1.181813374574563, -4.015226795812303, 1.1791252321256136, 3.0288442662891657, 0.5302119856230764, -4.544904200575785, -1.3507561931008494, -1.4892011899337079, -0.7067730816292567, 1.6050085300313415, -0.9485161081805294, -0.3676329642271417, -2.8166050957043645, -1.7576086754189393, -0.4212658109839059, 1.4086239241698197, 1.2722894764347437, -1.4916320344220986, -3.2768254567268715, -0.9275826155587327, -1.8318434740162455, 2.2025128394017024, -1.1962946789470175, 2.6741224965587795, -0.9282428023424697, -1.488402954027688, 1.1913582655693191, -1.1461040835246719, -4.174686695344578, 3.282925640692415, -1.1615506200067285, -1.2819073755378279, 5.673939535980798, -2.164225607518202, -0.24046662020212925, -0.6160494598513213, 1.3199187334377425, -1.1392707987434096, 0.5396986569888228, 3.5693486953207745, -0.31131785551753166, 0.9578404762946777, -1.0817982039266394, -0.9496599048571059, 0.6127641245663844, 0.46549080083132366, -0.05183387387356686, -1.5799720396627652, 3.4367709247325604, 3.3719520505600853, 2.1641372369243213, 6.346030873783134, 0.6316330396558133, 1.209196936908671, 0.8967290801401666, -1.3350112569101658, -0.8888652367005211, -0.9454664384355618, 0.7074896172194814, -1.7260205374344793, -1.795621345551233, 0.7268953634561259, 2.137592363738138, 0.6909572065222959, 0.5704677578740222, -2.2094867247084102, -4.938449734258957, 0.41124294680910306, 
-3.1250987547237434, -1.2083598018282107, 3.6617707318286734, -0.45318359813995795, -2.7636280973117646, 0.892079215087028, -1.4885960963799298, -1.5022544616175675, -0.35960679626358705, -1.5587201457708608, 2.004808613668589, -1.7469601032739492, 0.08193938727961608, 0.409800058961247, -0.313575097950315, -1.0072749750239949, 1.829867749919835, 6.449512583048983, -1.324792820447148, 3.2657605602450968, 0.48222703559541524, -0.22432109929685007, -3.931093098492196, 1.6106789547183167, 2.453537328075927, 1.4016015175822507, 6.324114958309967, 3.5035921243202566, 6.20319029849308, 0.3357021663643007, 1.7792581531315423, -0.5365700582820306, 4.453711728194251, -1.056604438095999, -0.8534860102942292, 0.8279611162960373, 2.5705352451485757, -1.6595582516589935, 4.6073113327002835, 0.34053799044720007, -0.15491360863846917, -5.607202405507818, 2.142119158650048, -1.9237813569759044, -0.785672754331302, -3.1990644464740976, -0.06334213331352366, -3.2424992354712394, 2.5504353338963517, 1.294684276628422, -3.435619702490956, -0.04181628037156118, 3.3550643524999795, 1.0681744300478533, 0.8993020375146389, -4.121042583355506, 0.5825199402003376, 3.3552021833542147, 1.2894706212089098, 3.4049555691242652, -1.1352569361826519, -0.962690813856548, 5.322808618872159, -5.730699170915555, 1.8287791048247213, 0.5971708039698511, 1.3829802847707449, -1.1064215013494811, -4.066199983474573, -2.012840738538371, 1.4129711690101967, 3.3198205646269248, -2.6773198000611425, -0.7586311654253342, 0.03608170133611735, 0.4956996046873562, -0.8734723794070043, -2.6202090301039584, 0.7641837698220239, -0.45264932009389924, 0.8238499889487036, 3.57379478810016, -1.0712037492469086, 0.43813689428632285, -2.0850399168478124, 0.7727819034720036, 4.476144962433045, -1.29847536202929, -1.600008967668149, 1.863537749283198, 3.4922564815077313, -1.1564725051009874, -0.5186017583066849, 0.9636470406847663, -2.2944564236511917, 3.545301742855712, -0.8496197811688057, -2.162628396989053, 
-1.024489377466067, -1.0178097906910812, -1.048920708549003, 1.8386472579464308, -1.0611976709663076, 1.1926893259569369, 0.6326479201346191, -1.334477900084726, 1.2017567969124485];

error_data = Xe, ye
good_data = Xg, yg

# samples = example(good_data...)
samples = example(error_data...)

I would be very thankful if anyone could help me figure out what might be the cause of this.

When stuck, the sampler occasionally prints the following warning as well:

┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, true, false, true)
└ @ AdvancedHMC C:\Users\sheld\.julia\packages\AdvancedHMC\iWHPQ\src\hamiltonian.jl:47

However, I don’t think this is the reason the sampler gets completely stuck, since this warning sometimes also appears in runs that finish successfully.

Not sure if this is what’s causing your specific problem, because I know nothing about Turing, but in your error data you have two unequal measurements at exactly the same coordinate value (9.333333333333334), which is definitely a problem.

Considering that you’re using a diagonal perturbation, that shouldn’t ruin the Cholesky factorization that I assume is happening somewhere internally. But I’m going to guess that the $\ell$ parameter is an inverse range parameter (meaning your Matérn kernel looks like $e^{-\ell |t - t'|}(1 + \ell |t - t'| + \dots)$, skipping the proper scaling parameters), and I think in that case there is some intuition for why a sampler might accidentally take that parameter to infinity.
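For reference, KernelFunctions.jl defines `Matern52Kernel` as a function of the distance $d$ between inputs, and `with_lengthscale(k, l)` divides the inputs by $l$, so the lengthscale in the MWE enters through $1/l$ (variance parameters omitted, as above):

$$
k(t, t') = \left(1 + \frac{\sqrt{5}\,|t - t'|}{l} + \frac{5\,|t - t'|^2}{3\,l^2}\right) \exp\!\left(-\frac{\sqrt{5}\,|t - t'|}{l}\right)
$$

At $t = t'$ this equals $1$ for every $l$, so two identical inputs produce identical rows in the kernel matrix.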

Maybe try removing that repeated value and see what happens?
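To check a dataset for such repeats before sampling, here is a minimal Julia sketch. `find_duplicate_inputs` is a hypothetical helper, not part of Turing or AbstractGPs; for the MWE’s `Xe`, which is an N×1 matrix, you would call it on `vec(Xe)`.

```julia
# Hypothetical helper: groups of indices whose input values occur more than once.
# Uses exact equality, so values differing only in the last digit count as distinct.
function find_duplicate_inputs(x::AbstractVector)
    groups = Dict{eltype(x),Vector{Int}}()
    for (i, xi) in enumerate(x)
        push!(get!(groups, xi, Int[]), i)
    end
    # Keep only repeated locations, ordered by where they first appear.
    dups = [idx for idx in values(groups) if length(idx) > 1]
    return sort!(dups; by = first)
end

# Toy data with one repeated input (indices 2 and 4) and unequal outputs there.
x = [0.5, 9.333333333333334, 2.0, 9.333333333333334]
y = [1.0, -1.0, 0.3, 2.0]
for idx in find_duplicate_inputs(x)
    println("x = ", x[first(idx)], " at indices ", idx, " with y = ", y[idx])
end
```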


I think @cgeoga is on to something because it seems like whatever value you feed into the likelihood causes NaN gradients:

julia> using ForwardDiff

julia> function f(θ)
           noise, gp_lengthscale = θ
           gp = construct_finite_gp(model.args.X, gp_lengthscale, noise)
           return logpdf(gp, model.args.y)
       end
f (generic function with 1 method)

julia> θ = [1.0, 1.0]
2-element Vector{Float64}:
 1.0
 1.0

julia> ForwardDiff.gradient(f, θ)
2-element Vector{Float64}:
 NaN
 NaN

It might not be obvious from the warning, but ℓπ actually holds both the value and the gradient of the log density of the target (joint) distribution, so isfinite(ℓπ) being false means that at least one of the two contains an Inf or NaN.
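As a rough illustration of that point, here is a hypothetical container (it is not AdvancedHMC’s actual type) that holds a log density value together with its gradient and only counts as finite when both parts are finite:

```julia
# Hypothetical value-plus-gradient container, mirroring the idea behind ℓπ.
struct ValueAndGradient{T<:Real,G<:AbstractVector}
    value::T
    gradient::G
end

# Finite only if the value AND every gradient entry are finite.
Base.isfinite(v::ValueAndGradient) = isfinite(v.value) && all(isfinite, v.gradient)

ok  = ValueAndGradient(-150.0, [1.5, -2.0])
bad = ValueAndGradient(-150.0, [NaN, -2.0])  # finite value, NaN gradient

println(isfinite(ok))   # true
println(isfinite(bad))  # false, even though the value itself is finite
```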

EDIT: And indeed, removing the point @cgeoga is referring to (it’s the last one in the data) makes the code run :+1:


I confirm that either removing the identical point from data or fixing the lengthscale as constant and not sampling it resolves the issue.

I thought that since the model assumes non-zero noise, identical points would not cause problems. (Without it the Cholesky factorization would fail.) The problems with the kernel equation did not occur to me.
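A minimal sketch of that reasoning, using only LinearAlgebra: a hand-written Matérn 5/2 Gram matrix (unit variance, KernelFunctions’ lengthscale convention; an assumption about what the MWE’s kernel reduces to for scalar inputs) with one repeated input is singular, while adding a noise variance to the diagonal makes it positive definite. The inputs and the values of `l` and `noise` are illustrative.

```julia
using LinearAlgebra

# Matérn 5/2 covariance as a function of the distance r and lengthscale l (unit variance).
matern52(r, l) = (1 + sqrt(5) * r / l + 5 * r^2 / (3 * l^2)) * exp(-sqrt(5) * r / l)

# Gram matrix for scalar inputs x.
gram(x, l) = [matern52(abs(xi - xj), l) for xi in x, xj in x]

x = [0.5, 3.0, 9.333333333333334, 9.333333333333334]  # last two inputs coincide
l = 1.0       # illustrative lengthscale
noise = 0.1   # illustrative noise variance, added to the diagonal

K = Symmetric(gram(x, l))
Ky = Symmetric(gram(x, l) + noise * I)

println(minimum(eigvals(K)))   # ≈ 0 up to rounding: two identical rows make K singular
println(isposdef(Ky))          # true: the noise term restores positive definiteness
F = cholesky(Ky)               # so the Cholesky factorization succeeds
```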

I guess I will just have to find some way to deal with identical data.
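One possible way to deal with identical inputs, sketched here under the assumption of i.i.d. Gaussian observation noise, is to merge each group of repeats into a single point at the mean of its outputs. The helper `average_duplicates` is hypothetical. An averaged point then has noise variance `noise / count`; for a fixed noise value this leaves the posterior over the latent function unchanged, but the marginal likelihood drops a term that depends on the spread within each group, so it is not an exact replacement when `noise` itself is sampled. Using the result would also need per-point noise variances in the GP construction (I believe AbstractGPs accepts a vector there, but treat that as an assumption).

```julia
# Hypothetical preprocessing: collapse repeated inputs into one point whose
# output is the mean of the repeats, and return how many repeats each point had.
function average_duplicates(x::AbstractVector, y::AbstractVector)
    length(x) == length(y) || throw(DimensionMismatch("x and y must have the same length"))
    sums = Dict{eltype(x),Float64}()
    counts = Dict{eltype(x),Int}()
    for (xi, yi) in zip(x, y)
        sums[xi] = get(sums, xi, 0.0) + yi
        counts[xi] = get(counts, xi, 0) + 1
    end
    xu = sort!(collect(keys(sums)))
    yu = [sums[xi] / counts[xi] for xi in xu]
    nu = [counts[xi] for xi in xu]
    return xu, yu, nu
end

x = [0.5, 9.333333333333334, 2.0, 9.333333333333334]
y = [1.0, -1.0, 0.3, 2.0]
xu, yu, nu = average_duplicates(x, y)

# With i.i.d. noise of variance `noise`, the mean of n repeats has variance noise / n.
noise = 0.1
noise_per_point = noise ./ nu
```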

Thank you both for your time and help. I’m marking the problem as solved. :slight_smile:


To elaborate slightly on what the issue is: most GP models that I’m aware of, and certainly any that I use, do assume that you’re measuring a function. I would guess your model, ignoring any mathematical details, says that your observed data come from $Y(t) = f(t) + \epsilon(t)$, where $f$ is some unknown function and $\epsilon(t)$ represents some kind of measurement noise, and uses a GP to approximate $f$. But in your case, or at least in how you’ve specified the model in code, you only have one sample path from this process. So at each $t_0$, you really should have only one value for $Y(t_0)$, or else $Y(t)$ is not a (real-valued) function.
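For comparison, the model as written in the MWE gives the observations the distribution below, where each observation gets its own independent noise term ($\sigma^2$ is the `noise` variance, $k$ the kernel with unit variance, and $\mathbf{m}_i = m(t_i)$ the mean function):

$$
y \sim \mathcal{N}\left(\mathbf{m},\; K + \sigma^2 I\right), \qquad K_{ij} = k(t_i, t_j).
$$

If $t_i = t_j = t$, then $k(t, t) = 1$ and the corresponding $2 \times 2$ block of the covariance is

$$
\begin{pmatrix} 1 + \sigma^2 & 1 \\ 1 & 1 + \sigma^2 \end{pmatrix},
\qquad \det = (1 + \sigma^2)^2 - 1 = \sigma^2 \left(2 + \sigma^2\right) > 0 \quad \text{for } \sigma^2 > 0,
$$

so under this per-observation noise model two different values at the same $t$ are allowed; whether that matches the intended model is the question raised above.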

If you had replicates or some other kind of related draws, that would be different, and maybe that’s what you actually have in mind. But when your random variables are actually functions, there is an important distinction between making more and more measurements that observe a single sample path more and more finely, and observing a finite number of measurements from each of several sample paths.

Anyway, I apologize if this is something you already knew, and good luck with your modeling!


I hadn’t worked on this problem for the last couple of days, but I returned to it today, and I don’t think there is actually an issue with the way the model is defined. Instead, I think there is a problem in the gradient calculation.

If I change the lengthscale slightly, the log-likelihood changes as well, as shown below, so I think the gradient should exist. (Correct me if I’m wrong on this.)

Loglikelihood Computation
julia> function f(lengthscale)
           noise = 0.1
           gp = construct_finite_gp(Xe, first(lengthscale), noise)
           return logpdf(gp, ye)
       end
f (generic function with 1 method)

julia> f(0.9)
-158.09618544618976

julia> f(1.0)
-178.65913509460518

julia> f(1.1)
-205.45899882684503
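A quick AD-free check of that claim is a central finite difference. In this sketch, `central_difference` is a hypothetical helper; the sanity check uses a function with a known derivative, and the last lines reuse the values of f(0.9) and f(1.1) printed above. They give a slope of about -236.8 around a lengthscale of 1.0 (only approximate, since the step is a coarse 0.1 on each side), which is close to the Zygote gradient reported below and suggests the derivative is finite there.

```julia
# Central finite difference: approximates f'(x) without any AD backend.
central_difference(f, x; h = 1e-5) = (f(x + h) - f(x - h)) / (2h)

# Sanity check on a known derivative: d/dx x^3 = 3x^2, which is 12 at x = 2.
check = central_difference(x -> x^3, 2.0)

# Log-likelihood values printed above at lengthscale 0.9 and 1.1 (noise = 0.1).
f_09 = -158.09618544618976
f_11 = -205.45899882684503
slope = (f_11 - f_09) / (1.1 - 0.9)   # approximate derivative at lengthscale 1.0
println(slope)
```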

Furthermore, with Zygote the gradient calculation actually works:

Gradient Calculation
julia> ForwardDiff.gradient(f, [1.])
1-element Vector{Float64}:
 NaN

julia> ReverseDiff.gradient(f, [1.])
1-element Vector{Float64}:
 NaN

julia> Zygote.gradient(f, [1.])
([-235.96836642578603],)

I don’t know why some AD backends work and others don’t. It may be a numerical issue or a bug in the backend packages.
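One possible mechanism, offered as an assumption rather than something established in this thread: if the kernel computes the Euclidean distance between (scaled) inputs with `sqrt`, forward-mode differentiation of `sqrt` at exactly zero evaluates 0/0, so a pair of repeated inputs yields a NaN derivative even though the distance itself is a perfectly good 0. The toy `Dual` type below is hypothetical and far simpler than ForwardDiff, but it applies the same chain rule:

```julia
# Minimal forward-mode dual number: a value plus one derivative component.
# Hypothetical illustration only, not ForwardDiff's implementation.
struct Dual
    val::Float64
    der::Float64
end

Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:/(a::Dual, b::Dual) = Dual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)
# Chain rule for sqrt: d sqrt(u) = u' / (2 sqrt(u)), which is 0/0 when u = 0 and u' = 0.
Base.sqrt(a::Dual) = Dual(sqrt(a.val), a.der / (2 * sqrt(a.val)))

# 1-D distance between x/l and xp/l, differentiated with respect to l.
function scaled_distance(x, xp, l::Dual)
    d = Dual(x, 0.0) / l - Dual(xp, 0.0) / l
    return sqrt(d * d)
end

l = Dual(1.0, 1.0)  # seed: dl/dl = 1
r_distinct = scaled_distance(2.0, 3.0, l)
r_repeated = scaled_distance(9.333333333333334, 9.333333333333334, l)

println(r_distinct)  # Dual(1.0, -1.0)
println(r_repeated)  # Dual(0.0, NaN)
```

That would be consistent with removing the duplicate point fixing the ForwardDiff run, but the thread does not establish why Zygote behaves differently.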

Anyway, for completeness: the solution is to set the AD backend to Zygote by adding the following line:

Turing.setadbackend(:zygote)

I’ve tested the code and it runs fine with Zygote even with duplicate data points.