using Flux, Statistics
using Flux.Data: DataLoader
using Flux: @epochs
using Flux.Losses: huber_loss, tversky_loss, mse
using Base: @kwdef
using CUDA
function Data_importer()
train_x = [0.0 0.0 0.0 0.0; 9.17825759553822 9.17825759553822 9.17825759553822 9.17825759553822; 208.88979427867432 208.88979427867432 208.88979427867432 208.88979427867432; 86.22968000074032 86.22968000074032 86.22968000074032 86.22968000074032; 13.711736733314817 13.711736733314817 13.711736733314817 13.711736733314817; 8.481996706383539 8.481996706383539 8.481996706383539 8.481996706383539; 8.259028161102316 8.259028161102316 8.259028161102316 8.259028161102316; 8.250346842514638 8.250346842514638 8.250346842514638 8.250346842514638; 0.0 0.0 0.0 0.0; 0.0033602863460109127 0.0033602863460109127 0.0033602863460109127 0.0033602863460109127; 0.003393828186303037 0.003393828186303037 0.003393828186303037 0.003393828186303037; 0.0015699825966405537 0.0015699825966405537 0.0015699825966405537 0.0015699825966405537; 0.0009128876712430804 0.0009128876712430804 0.0009128876712430804 0.0009128876712430804; 0.0005297583769110372 0.0005297583769110372 0.0005297583769110372 0.0005297583769110372; 0.000307524894698159 0.000307524894698159 0.000307524894698159 0.000307524894698159; 0.00017884235549504642 0.00017884235549504642 0.00017884235549504642 0.00017884235549504642; 0.0 0.0 0.0 0.0; 10.25422212800015 10.25422212800015 10.25422212800015 10.25422212800015; 13.98005927420319 13.98005927420319 13.98005927420319 13.98005927420319; 7.633807093482109 7.633807093482109 7.633807093482109 7.633807093482109; 3.985199951372045 3.985199951372045 3.985199951372045 3.985199951372045; 2.035290990628521 2.035290990628521 2.035290990628521 2.035290990628521; 1.0410181509339318 1.0410181509339318 1.0410181509339318 1.0410181509339318; 0.5398957165592184 0.5398957165592184 0.5398957165592184 0.5398957165592184; 4.437870860867349 4.437870860867349 4.437870860867349 4.437870860867349; 3.6290720215197387 3.6290720215197387 3.6290720215197387 3.6290720215197387; 3.3710853614933476 3.3710853614933476 3.3710853614933476 3.3710853614933476; 0.0 0.0 0.0 0.0; 0.8954068246528882 0.8954068246528882 0.8954068246528882 0.8954068246528882; 3.180081533777011 3.180081533777011 3.180081533777011 3.180081533777011; 1.717803176112131 1.717803176112131 1.717803176112131 1.717803176112131; 0.5866063585180025 0.5866063585180025 0.5866063585180025 0.5866063585180025; 0.2003989848710116 0.2003989848710116 0.2003989848710116 0.2003989848710116; 0.07662918471655883 0.07662918471655883 0.07662918471655883 0.07662918471655883; 0.03285020877003695 0.03285020877003695 0.03285020877003695 0.03285020877003695; 0.0 0.0 0.0 0.0; 0.1277455368175085 0.1277455368175085 0.1277455368175085 0.1277455368175085; 0.9041545413730231 0.9041545413730231 0.9041545413730231 0.9041545413730231; 2.015779141630358 2.015779141630358 2.015779141630358 2.015779141630358; 2.233739985634341 2.233739985634341 2.233739985634341 2.233739985634341; 2.2951307932962957 2.2951307932962957 2.2951307932962957 2.2951307932962957; 2.311706429371978 2.311706429371978 2.311706429371978 2.311706429371978; 2.3146253736089615 2.3146253736089615 2.3146253736089615 2.3146253736089615; 0.0 0.0 0.0 0.0; 0.21544701419224335 0.21544701419224335 0.21544701419224335 0.21544701419224335; 0.8643671898900301 0.8643671898900301 0.8643671898900301 0.8643671898900301; 1.0789438343976807 1.0789438343976807 1.0789438343976807 1.0789438343976807; 0.8264001513165987 0.8264001513165987 0.8264001513165987 0.8264001513165987; 0.559729339546755 0.559729339546755 0.559729339546755 0.559729339546755; 0.35624374055504393 0.35624374055504393 0.35624374055504393 0.35624374055504393; 0.21904639090224073 0.21904639090224073 0.21904639090224073 0.21904639090224073; 0.0 0.0 0.0 0.0; 10.83328521954551 10.83328521954551 10.83328521954551 10.83328521954551; 25.497900864526745 25.497900864526745 25.497900864526745 25.497900864526745; 45.652475990325485 45.652475990325485 45.652475990325485 45.652475990325485; 56.670253412735384 56.670253412735384 56.670253412735384 56.670253412735384; 65.55386181740859 65.55386181740859 65.55386181740859 65.55386181740859; 72.40643880100566 72.40643880100566 72.40643880100566 72.40643880100566; 77.51342298440797 77.51342298440797 77.51342298440797 77.51342298440797; 0.0 0.0 0.0 0.0; 0.8967867907308589 0.8967867907308589 0.8967867907308589 0.8967867907308589; 4.918936136105033 4.918936136105033 4.918936136105033 4.918936136105033; 8.292012929977501 8.292012929977501 8.292012929977501 8.292012929977501; 7.254512052922267 7.254512052922267 7.254512052922267 7.254512052922267; 5.451675864658739 5.451675864658739 5.451675864658739 5.451675864658739; 3.7754110203152083 3.7754110203152083 3.7754110203152083 3.7754110203152083; 2.4898798229566546 2.4898798229566546 2.4898798229566546 2.4898798229566546; 0.0 0.0 0.0 0.0; 1.460691882519804 1.460691882519804 1.460691882519804 1.460691882519804; 7.57065251053054 7.57065251053054 7.57065251053054 7.57065251053054; 10.570986968627574 10.570986968627574 10.570986968627574 10.570986968627574; 7.3784653117753916 7.3784653117753916 7.3784653117753916 7.3784653117753916; 4.20130183076954 4.20130183076954 4.20130183076954 4.20130183076954; 2.1695321090483577 2.1695321090483577 2.1695321090483577 2.1695321090483577; 1.0738186433327617 1.0738186433327617 1.0738186433327617 1.0738186433327617; 0.0 0.0 0.0 0.0; 0.5492304854926758 0.5492304854926758 0.5492304854926758 0.5492304854926758; 3.206984127978817 3.206984127978817 3.206984127978817 3.206984127978817; 6.113384121173109 6.113384121173109 6.113384121173109 6.113384121173109; 5.892001304969062 5.892001304969062 5.892001304969062 5.892001304969062; 4.903702489828794 4.903702489828794 4.903702489828794 4.903702489828794; 3.7685121656563942 3.7685121656563942 3.7685121656563942 3.7685121656563942; 2.758515525033501 2.758515525033501 2.758515525033501 2.758515525033501; 0.0 0.0 0.0 0.0; 0.004046802945817767 0.004046802945817767 0.004046802945817767 0.004046802945817767; 0.011756448834050696 0.011756448834050696 0.011756448834050696 0.011756448834050696; 0.007022005639101842 0.007022005639101842 0.007022005639101842 0.007022005639101842; 0.0026274270463931366 0.0026274270463931366 0.0026274270463931366 0.0026274270463931366; 0.0009233212090163512 0.0009233212090163512 0.0009233212090163512 0.0009233212090163512; 0.00035456145851238456 0.00035456145851238456 0.00035456145851238456 0.00035456145851238456; 0.00015178794666672381 0.00015178794666672381 0.00015178794666672381 0.00015178794666672381; 0.0 0.0 0.0 0.0; 3.630838983977378 3.630838983977378 3.630838983977378 3.630838983977378; 10.986059035911685 10.986059035911685 10.986059035911685 10.986059035911685; 11.509242269867258 11.509242269867258 11.509242269867258 11.509242269867258; 8.230758773685777 8.230758773685777 8.230758773685777 8.230758773685777; 4.795391761441194 4.795391761441194 4.795391761441194 4.795391761441194; 2.4472246033065117 2.4472246033065117 2.4472246033065117 2.4472246033065117; 1.1700422208228751 1.1700422208228751 1.1700422208228751 1.1700422208228751; 242.35922163776416 242.35922163776416 242.35922163776416 242.35922163776416; 13.711736733314817 13.711736733314817 13.711736733314817 13.711736733314817; 0.002668334901823087 0.002668334901823087 0.002668334901823087 0.002668334901823087; 0.0009128876712430804 0.0009128876712430804 0.0009128876712430804 0.0009128876712430804; 12.760879764165292 12.760879764165292 12.760879764165292 12.760879764165292; 3.985199951372045 3.985199951372045 3.985199951372045 3.985199951372045; 3.3200931509430185 3.3200931509430185 3.3200931509430185 3.3200931509430185; 0.5866063585180025 0.5866063585180025 0.5866063585180025 0.5866063585180025; 1.3847131486936883 1.3847131486936883 1.3847131486936883 1.3847131486936883; 2.233739985634341 2.233739985634341 2.233739985634341 2.233739985634341; 1.0711448833479906 1.0711448833479906 1.0711448833479906 1.0711448833479906; 0.8264001513165987 0.8264001513165987 0.8264001513165987 0.8264001513165987; 32.62984147688874 32.62984147688874 32.62984147688874 32.62984147688874; 56.670253412735384 56.670253412735384 56.670253412735384 56.670253412735384; 6.867370333654399 6.867370333654399 6.867370333654399 6.867370333654399; 7.254512052922267 7.254512052922267 7.254512052922267 7.254512052922267; 10.11848390214388 10.11848390214388 10.11848390214388 10.11848390214388; 7.3784653117753916 7.3784653117753916 7.3784653117753916 7.3784653117753916; 4.645712717143942 4.645712717143942 4.645712717143942 4.645712717143942; 5.892001304969062 5.892001304969062 5.892001304969062 5.892001304969062; 0.012073508388332998 0.012073508388332998 0.012073508388332998 0.012073508388332998; 0.0026274270463931366 0.0026274270463931366 0.0026274270463931366 0.0026274270463931366; 12.360684318680958 12.360684318680958 12.360684318680958 12.360684318680958; 8.230758773685777 8.230758773685777 8.230758773685777 8.230758773685777; 0.0 0.0 0.0 0.0; 9.183147762109897 9.183147762109897 9.183147762109897 9.183147762109897; 272.89942159723154 272.89942159723154 272.89942159723154 272.89942159723154; 385.66854642422527 385.66854642422527 385.66854642422527 385.66854642422527; 231.60175795245044 231.60175795245044 231.60175795245044 231.60175795245044; 31.882982498484374 31.882982498484374 31.882982498484374 31.882982498484374; 9.263574293776522 9.263574293776522 9.263574293776522 9.263574293776522; 8.289143195524618 8.289143195524618 8.289143195524618 8.289143195524618; 0.0 0.0 0.0 0.0; 0.0033936233269959266 0.0033936233269959266 0.0033936233269959266 0.0033936233269959266; 0.005312994052635443 0.005312994052635443 0.005312994052635443 0.005312994052635443; 0.003455924109309518 0.003455924109309518 0.003455924109309518 0.003455924109309518; 0.0020120100779384194 0.0020120100779384194 0.0020120100779384194 0.0020120100779384194; 0.0011657789723958994 0.0011657789723958994 0.0011657789723958994 0.0011657789723958994; 0.0006761338820777927 0.0006761338820777927 0.0006761338820777927 0.0006761338820777927; 0.00039343572970745356 0.00039343572970745356 0.00039343572970745356 0.00039343572970745356; 0.0 0.0 0.0 0.0; 10.264950158032645 10.264950158032645 10.264950158032645 10.264950158032645; 14.57029950413034 14.57029950413034 14.57029950413034 14.57029950413034; 5.973187880105501 5.973187880105501 5.973187880105501 5.973187880105501; 2.9161221160029323 2.9161221160029323 2.9161221160029323 2.9161221160029323; 1.4223058261310046 1.4223058261310046 1.4223058261310046 1.4223058261310046; 0.6942509524984898 0.6942509524984898 0.6942509524984898 0.6942509524984898; 0.3395684079224356 0.3395684079224356 0.3395684079224356 0.3395684079224356; 0.0 0.0 0.0 0.0; 0.8960371616330409 0.8960371616330409 0.8960371616330409 0.8960371616330409; 3.5558705856950854 3.5558705856950854 3.5558705856950854 3.5558705856950854; 5.083922609618896 5.083922609618896 5.083922609618896 5.083922609618896; 2.6278617480133883 2.6278617480133883 2.6278617480133883 2.6278617480133883; 1.1223757278857793 1.1223757278857793 1.1223757278857793 1.1223757278857793; 0.5372106361174306 0.5372106361174306 0.5372106361174306 0.5372106361174306; 0.2822878251955673 0.2822878251955673 0.2822878251955673 0.2822878251955673; 0.0 0.0 0.0 0.0; 0.17838550815239795 0.17838550815239795 0.17838550815239795 0.17838550815239795; 1.7526366986845667 1.7526366986845667 1.7526366986845667 1.7526366986845667; 6.80854636600655 6.80854636600655 6.80854636600655 6.80854636600655; 9.51762590982706 9.51762590982706 9.51762590982706 9.51762590982706; 10.771204764133687 10.771204764133687 10.771204764133687 10.771204764133687; 11.326657894279872 11.326657894279872 11.326657894279872 11.326657894279872; 11.593266061188158 11.593266061188158 11.593266061188158 11.593266061188158; 0.0 0.0 0.0 0.0; 0.29944597843254145 0.29944597843254145 0.29944597843254145 0.29944597843254145; 1.7367042869727944 1.7367042869727944 1.7367042869727944 1.7367042869727944; 3.7908534496545765 3.7908534496545765 3.7908534496545765 3.7908534496545765; 2.847901261307619 2.847901261307619 2.847901261307619 2.847901261307619; 1.7019718269174764 1.7019718269174764 1.7019718269174764 1.7019718269174764; 0.962944445586322 0.962944445586322 0.962944445586322 0.962944445586322; 0.5362649743551858 0.5362649743551858 0.5362649743551858 0.5362649743551858; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.8971321730692192 0.8971321730692192 0.8971321730692192 0.8971321730692192; 5.186510539291743 5.186510539291743 5.186510539291743 5.186510539291743; 12.382450981116792 12.382450981116792 12.382450981116792 12.382450981116792; 12.661019141023448 12.661019141023448 12.661019141023448 12.661019141023448; 10.32789996044903 10.32789996044903 10.32789996044903 10.32789996044903; 7.486344400036383 7.486344400036383 7.486344400036383 7.486344400036383; 5.076279286195816 5.076279286195816 5.076279286195816 5.076279286195816; 0.0 0.0 0.0 0.0; 1.461376437635734 1.461376437635734 1.461376437635734 1.461376437635734; 8.05319814910555 8.05319814910555 8.05319814910555 8.05319814910555; 17.180043906818515 17.180043906818515 17.180043906818515 17.180043906818515; 13.707201194502158 13.707201194502158 13.707201194502158 13.707201194502158; 8.233012260075533 8.233012260075533 8.233012260075533 8.233012260075533; 4.554732841496146 4.554732841496146 4.554732841496146 4.554732841496146; 2.474768006641207 2.474768006641207 2.474768006641207 2.474768006641207; 0.0 0.0 0.0 0.0; 0.5494356510305276 0.5494356510305276 0.5494356510305276 0.5494356510305276; 3.3711709610048155 3.3711709610048155 3.3711709610048155 3.3711709610048155; 8.745418950316784 8.745418950316784 8.745418950316784 8.745418950316784; 9.134660179840424 9.134660179840424 9.134660179840424 9.134660179840424; 7.536250849781278 7.536250849781278 7.536250849781278 7.536250849781278; 5.627834359550135 5.627834359550135 5.627834359550135 5.627834359550135; 4.003730975043757 4.003730975043757 4.003730975043757 4.003730975043757; 0.0 0.0 0.0 0.0; 0.004093049761109787 0.004093049761109787 0.004093049761109787 0.004093049761109787; 0.013100383971910939 0.013100383971910939 0.013100383971910939 0.013100383971910939; 0.01685973424939562 0.01685973424939562 0.01685973424939562 0.01685973424939562; 0.010419245948019877 0.010419245948019877 0.010419245948019877 0.010419245948019877; 0.0050518885474255745 0.0050518885474255745 0.0050518885474255745 0.0050518885474255745; 0.0025507298567488865 0.0025507298567488865 0.0025507298567488865 0.0025507298567488865; 0.0013732756944053045 0.0013732756944053045 0.0013732756944053045 0.0013732756944053045; 0.0 0.0 0.0 0.0; 3.7305272193429055 3.7305272193429055 3.7305272193429055 3.7305272193429055; 12.650920481778668 12.650920481778668 12.650920481778668 12.650920481778668; 18.950625506685032 18.950625506685032 18.950625506685032 18.950625506685032; 19.391613942969105 19.391613942969105 19.391613942969105 19.391613942969105; 17.274502061451734 17.274502061451734 17.274502061451734 17.274502061451734; 13.680045380285472 13.680045380285472 13.680045380285472 13.680045380285472; 9.941248998688334 9.941248998688334 9.941248998688334 9.941248998688334; 292.77563303392054 292.77563303392054 292.77563303392054 292.77563303392054; 27.214565661145414 27.214565661145414 27.214565661145414 27.214565661145414; 0.0032122739226327617 0.0032122739226327617 0.0032122739226327617 0.0032122739226327617; 0.0011607996841825358 0.0011607996841825358 0.0011607996841825358 0.0011607996841825358; 12.971732976956275 12.971732976956275 12.971732976956275 12.971732976956275; 4.696162195031036 4.696162195031036 4.696162195031036 4.696162195031036; 3.8277241678388316 3.8277241678388316 3.8277241678388316 3.8277241678388316; 1.250457539918555 1.250457539918555 1.250457539918555 1.250457539918555; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 1.1158940143342528 1.1158940143342528 1.1158940143342528 1.1158940143342528; 0.9435484357253695 0.9435484357253695 0.9435484357253695 0.9435484357253695; 32.70960306250701 32.70960306250701 32.70960306250701 32.70960306250701; 57.35067073622422 57.35067073622422 57.35067073622422 57.35067073622422; 7.190978302233153 7.190978302233153 7.190978302233153 7.190978302233153; 8.423633688327191 8.423633688327191 8.423633688327191 8.423633688327191; 10.625780453324678 10.625780453324678 10.625780453324678 10.625780453324678; 8.529421678139203 8.529421678139203 8.529421678139203 8.529421678139203; 4.849845633514303 4.849845633514303 4.849845633514303 4.849845633514303; 6.750887271299424 6.750887271299424 6.750887271299424 6.750887271299424; 0.013401994604077953 0.013401994604077953 0.013401994604077953 0.013401994604077953; 0.005264826898432422 0.005264826898432422 0.005264826898432422 0.005264826898432422; 12.464445239560515 12.464445239560515 12.464445239560515 12.464445239560515; 9.450206038945831 9.450206038945831 9.450206038945831 9.450206038945831]
train_y = [0.47325077443766045 0.47325077443766045 0.47325077443766045 0.47325077443766045; 0.4223366661848868 0.4223366661848868 0.4223366661848868 0.4223366661848868; 0.32794944372177465 0.32794944372177465 0.32794944372177465 0.32794944372177465; 0.3534293242231996 0.3534293242231996 0.3534293242231996 0.3534293242231996; 0.516723888647999 0.516723888647999 0.516723888647999 0.516723888647999; 0.35984006970664795 0.35984006970664795 0.35984006970664795 0.35984006970664795; 0.8333333333333334 0.8333333333333334 0.8333333333333334 0.8333333333333334; 0.3986717366142045 0.3986717366142045 0.3986717366142045 0.3986717366142045; 0.4611683330924434 0.4611683330924434 0.4611683330924434 0.4611683330924434; 0.2222222222222222 0.2222222222222222 0.2222222222222222 0.2222222222222222; 0.7487289585955368 0.7487289585955368 0.7487289585955368 0.7487289585955368; 0.35984006970664795 0.35984006970664795 0.35984006970664795 0.35984006970664795; 0.4223366661848868 0.4223366661848868 0.4223366661848868 0.4223366661848868; 0.3112255550737757 0.3112255550737757 0.3112255550737757 0.3112255550737757; 0.6085690283021847 0.6085690283021847 0.6085690283021847 0.6085690283021847; 0.35984006970664795 0.35984006970664795 0.35984006970664795 0.35984006970664795; 0.3667811106293313 0.3667811106293313 0.3667811106293313 0.3667811106293313; 0.35005722198133227 0.35005722198133227 0.35005722198133227 0.35005722198133227; 0.516723888647999 0.516723888647999 0.516723888647999 0.516723888647999; 0.3679582939109945 0.3679582939109945 0.3679582939109945 0.3679582939109945; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.3835049992773302 0.3835049992773302 0.3835049992773302 0.3835049992773302; 0.3835049992773302 0.3835049992773302 0.3835049992773302 0.3835049992773302; 0.40791237115678924 0.40791237115678924 0.40791237115678924 0.40791237115678924; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.6278349997591101 0.6278349997591101 0.6278349997591101 0.6278349997591101; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5 0.5 0.5 0.5; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.6543417361324246 0.6543417361324246 0.6543417361324246 0.6543417361324246; 0.5388316669075566 0.5388316669075566 0.5388316669075566 0.5388316669075566; 0.2556699995182201 0.2556699995182201 0.2556699995182201 0.2556699995182201; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.47095118081775905 0.47095118081775905 0.47095118081775905 0.47095118081775905; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.516723888647999 0.516723888647999 0.516723888647999 0.516723888647999; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.516723888647999 0.516723888647999 0.516723888647999 0.516723888647999; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.4611683330924434 0.4611683330924434 0.4611683330924434 0.4611683330924434; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5385891777185317 0.5385891777185317 0.5385891777185317 0.5385891777185317; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.5265067363733146 0.5265067363733146 0.5265067363733146 0.5265067363733146; 0.6134107047310126 0.6134107047310126 0.6134107047310126 0.6134107047310126; 0.7136165577785698 0.7136165577785698 0.7136165577785698 0.7136165577785698; 0.5 0.5 0.5 0.5; 0.541131260527458 0.541131260527458 0.541131260527458 0.541131260527458; 0.4611683330924434 0.4611683330924434 0.4611683330924434 0.4611683330924434; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.516723888647999 0.516723888647999 0.516723888647999 0.516723888647999; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5820622919288702 0.5820622919288702 0.5820622919288702 0.5820622919288702; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.5451618531468253 0.5451618531468253 0.5451618531468253 0.5451618531468253; 0.5334477772959979 0.5334477772959979 0.5334477772959979 0.5334477772959979; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.516723888647999 0.516723888647999 0.516723888647999 0.516723888647999; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.49461611038844133 0.49461611038844133 0.49461611038844133 0.49461611038844133; 0.6134107047310126 0.6134107047310126 0.6134107047310126 0.6134107047310126; 0.4557844434808847 0.4557844434808847 0.4557844434808847 0.4557844434808847; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5265067363733146 0.5265067363733146 0.5265067363733146 0.5265067363733146; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5265067363733146 0.5265067363733146 0.5265067363733146 0.5265067363733146; 0.22662118033597917 0.22662118033597917 0.22662118033597917 0.22662118033597917; 0.2222222222222222 0.2222222222222222 0.2222222222222222 0.2222222222222222; 0.3888888888888889 0.3888888888888889 0.3888888888888889 0.3888888888888889; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5265067363733146 0.5265067363733146 0.5265067363733146 0.5265067363733146; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.5265067363733146 0.5265067363733146 0.5265067363733146 0.5265067363733146; 0.44674403806434587 0.44674403806434587 0.44674403806434587 0.44674403806434587; 0.7054983335742233 0.7054983335742233 0.7054983335742233 0.7054983335742233; 0.5432306250213136 0.5432306250213136 0.5432306250213136 0.5432306250213136; 0.5 0.5 0.5 0.5; 0.7858960019821244 0.7858960019821244 0.7858960019821244 0.7858960019821244; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.5043989581137569 0.5043989581137569 0.5043989581137569 0.5043989581137569; 0.7598163116458259 0.7598163116458259 0.7598163116458259 0.7598163116458259; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.7054983335742233 0.7054983335742233 0.7054983335742233 0.7054983335742233; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.6111111111111112 0.6111111111111112 0.6111111111111112 0.6111111111111112; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.5 0.5 0.5 0.5; 1.0 1.0 1.0 1.0; 0.5388316669075566 0.5388316669075566 0.5388316669075566 0.5388316669075566; 0.5265067363733146 0.5265067363733146 0.5265067363733146 0.5265067363733146; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.4375034035217611 0.4375034035217611 0.4375034035217611 0.4375034035217611; 0.05555555555555555 0.05555555555555555 0.05555555555555555 0.05555555555555555; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444; 0.4444444444444444 0.4444444444444444 0.4444444444444444 0.4444444444444444]
return train_x, train_y
end
function build_model(args)
return Chain(
Dense(243, 500,relu),
Dense(500, 700),
Dense(700, 600),
Dense(600, 500),
Dense(500, 400),
Dense(400, 300),
Dense(300, 200),
Dense(200, 150),
Dense(150,args.num_parameters,NNlib.σ)
)
end
@kwdef mutable struct Args
η::Float64 = 1e-4 # learning rate
batchsize::Int = 3 # batch size
num_parameters::Int = 105 # Number of parameters of our model
num_measurements::Int = 243 # Number of all measurements
num_data::Int = 3 # number of data that is createtd
epochs::Int =100 # number of epochs
use_cuda::Bool = true # use gpu (if cuda available)
infotime::Int = 1 # report every `infotime` epochs
end
loss(ŷ, y) = huber_loss(ŷ, y)
function loss_and_accuracy(data_loader, model, device)
acc = 0
ls = 0.0f0
num = 0
i = 0
for (x, y) in data_loader
#println(size(y))
x, y = device(x), device(y)
ŷ = model(x)
ŷ_pow = power_taker(ŷ)
y_pow = power_taker(y)
acc_tmp = maximum(100 .*(abs.((ŷ_pow.-y_pow)./ŷ_pow)))
ls += loss(ŷ, y) * size(x)[end]
if acc_tmp > acc
acc = acc_tmp
end
num += size(x)[end]
i+=1
end
#println(i)
#return ls / num, acc
return ls / num, acc
end
function power_taker(list)
power_list = zeros(size(list)[1])
for i in 1:size(list)[1]
power_list[i] = 10^list[i]
end
return power_list
end
function train(; kws...)
args = Args(; kws...) # collect options in a struct for convenience
if CUDA.functional() && args.use_cuda
@info "Training on CUDA GPU"
CUDA.allowscalar(false)
device = gpu
else
@info "Training on CPU"
device = cpu
end
test_x, test_y = Data_importer()
#test_y = rand(args.num_parameters,1)
train_loader = DataLoader((test_x, test_y))
model = build_model(args) |> device
ps = Flux.params(model) # model's trainable parameters
## Optimizer
opt = ADAM(args.η)
best_loss = Inf
last_improvement = 1
device = cpu
function report(epoch)
train = loss_and_accuracy(train_loader, model, device)
println("Epoch: $epoch Train: $(train)")
end
## TRAINING
@info "Start Training"
report(0)
for epoch in 1:args.epochs
for (x, y) in train_loader
x, y = x |> device, y |> device
gs = Flux.gradient(ps) do
ŷ = model(x)
loss(ŷ, y)
end
Flux.Optimise.update!(opt, ps, gs)
end
epoch % args.infotime == 0 && report(epoch)
end
end
train()
It should be able to learn the connenction between the numbers really easily, however the error stays constant over the iterations. Does anyone have an idea what the problem is? Varying the learning rate didn’t really improve the result.