Sdílet prostřednictvím


microsoftml.rx_neural_network: Neurální síť

Usage

microsoftml.rx_neural_network(formula: str,
    data: [revoscalepy.datasource.RxDataSource.RxDataSource,
    pandas.core.frame.DataFrame], method: ['binary', 'multiClass',
    'regression'] = 'binary', num_hidden_nodes: int = 100,
    num_iterations: int = 100,
    optimizer: [<function adadelta_optimizer at 0x0000007156EAC048>,
    <function sgd_optimizer at 0x0000007156E9FB70>] = {'Name': 'SgdOptimizer',
    'Settings': {}}, net_definition: str = None,
    init_wts_diameter: float = 0.1, max_norm: float = 0,
    acceleration: [<function avx_math at 0x0000007156E9FEA0>,
    <function clr_math at 0x0000007156EAC158>,
    <function gpu_math at 0x0000007156EAC1E0>,
    <function mkl_math at 0x0000007156EAC268>,
    <function sse_math at 0x0000007156EAC2F0>] = {'Name': 'AvxMath',
    'Settings': {}}, mini_batch_size: int = 1, normalize: ['No',
    'Warn', 'Auto', 'Yes'] = 'Auto', ml_transforms: list = None,
    ml_transform_vars: list = None, row_selection: str = None,
    transforms: dict = None, transform_objects: dict = None,
    transform_function: str = None,
    transform_variables: list = None,
    transform_packages: list = None,
    transform_environment: dict = None, blocks_per_read: int = None,
    report_progress: int = None, verbose: int = 1,
    ensemble: microsoftml.modules.ensemble.EnsembleControl = None,
    compute_context: revoscalepy.computecontext.RxComputeContext.RxComputeContext = None)

Description

Neurální sítě pro regresní modelování a pro binární a vícetřídní klasifikaci.

Podrobnosti

Neurální síť je třída prediktivních modelů inspirovaných lidským mozkem. Neurální síť může být reprezentována jako vážený směrovaný graf. Každý uzel v grafu se nazývá neuron. Neurony v grafu jsou uspořádány ve vrstvách, kde neurony v jedné vrstvě jsou spojeny váženým okrajem (váhy mohou být 0 nebo kladná čísla) neuronům v další vrstvě. První vrstva se nazývá vstupní vrstva a každý neuron ve vstupní vrstvě odpovídá jedné z vlastností. Poslední vrstva funkce se nazývá výstupní vrstva. Takže v případě binárních neurálních sítí obsahuje dva výstupní neurony, jeden pro každou třídu, jehož hodnoty jsou pravděpodobnosti, že patří do každé třídy. Zbývající vrstvy se nazývají skryté vrstvy. Hodnoty neuronů ve skrytých vrstvách a ve výstupní vrstvě jsou nastaveny výpočtem váženého součtu hodnot neuronů v předchozí vrstvě a použitím aktivační funkce na tento vážený součet. Model neurální sítě je definován strukturou grafu (konkrétně počtem skrytých vrstev a počtem neuronů v každé skryté vrstvě), výběrem aktivační funkce a váhami na hranách grafu. Algoritmus neurální sítě se snaží naučit optimální váhy na hranách na základě trénovacích dat.

I když neurální sítě jsou široce známé pro hluboké učení a modelování složitých problémů, jako je rozpoznávání obrázků, jsou také snadno přizpůsobené regresním problémům. Libovolnou třídu statistických modelů lze považovat za neurální síť, pokud používají adaptivní váhy a mohou přibližné nelineární funkce jejich vstupů. Regrese neurální sítě je zvláště vhodná pro problémy, kdy tradiční regresní model nemůže vyhovovat řešení.

Arguments

vzorec

Vzorec, jak je popsáno v revoscalepy.rx_formula. Výrazy interakce a F() v microsoftml se v současné době nepodporují.

data

Objekt zdroje dat nebo řetězec znaku určující soubor .xdf nebo objekt datového rámce.

metoda

Znakový řetězec označující typ Fast Tree:

  • "binary" pro výchozí neurální síť binární klasifikace.

  • "multiClass" pro neurální síť klasifikace s více třídami.

  • "regression" pro regresní neurální síť.

num_hidden_nodes

Výchozí počet skrytých uzlů v neurální síti. Výchozí hodnota je 100.

num_iterations

Počet iterací v úplné trénovací sadě. Výchozí hodnota je 100.

optimizer

Seznam určující sgd algoritmus optimalizace.adaptive Tento seznam lze vytvořit pomocí sgd_optimizer nebo adadelta_optimizer. Výchozí hodnota je sgd.

net_definition

Definice Net# struktury neurální sítě. Další informace o jazyce Net# naleznete v referenční příručce

init_wts_diameter

Nastaví průměr počátečních hmotností, který určuje rozsah, ze kterého se hodnoty vykreslí pro počáteční hmotnost učení. Váhy se inicializují náhodně z tohoto rozsahu. Výchozí hodnota je 0,1.

max_norm

Určuje horní mez, která omezí normu příchozího vektoru hmotnosti v každé skryté jednotce. To může být velmi důležité v maximálních neurálních sítích i v případech, kdy trénování vytváří nevázané váhy.

Zrychlení

Určuje typ hardwarové akcelerace, která se má použít. Možné hodnoty jsou "sse_math" a "gpu_math". Pro akceleraci GPU se doporučuje použít miniBatchSize větší než jeden. Pokud chcete použít akceleraci GPU, je potřeba provést další kroky ručního nastavení:

  • Stáhněte a nainstalujte NVidia CUDA Toolkit 6.5 (CUDA Toolkit).

  • Stáhněte a nainstalujte knihovnu NVidia cuDNN v2 (cudnn Library).

  • Vyhledejte adresář libs balíčku microsoftml voláním import microsoftml, os, os.path.join(microsoftml.__path__[0], "mxLibs").

  • Zkopírujte cublas64_65.dll, cudart64_65.dll a cusparse64_65.dll ze sady CUDA Toolkit 6.5 do adresáře libs balíčku microsoftml.

  • Zkopírujte cudnn64_65.dll z knihovny cuDNN v2 do adresáře libs balíčku microsoftml.

mini_batch_size

Nastaví velikost minidávkové dávky. Doporučené hodnoty jsou mezi 1 a 256. Tento parametr se používá pouze v případě, že akcelerace je GPU. Nastavení tohoto parametru na vyšší hodnotu zlepšuje rychlost trénování, ale může negativně ovlivnit přesnost. Výchozí hodnota je 1.

Normalizovat

Určuje typ použité automatické normalizace:

  • "Warn": Pokud je potřeba normalizace, provede se automaticky. Toto je výchozí volba.

  • "No": Neprovádí se normalizace.

  • "Yes": normalizace se provádí.

  • "Auto": Pokud je potřeba normalizace, zobrazí se zpráva s upozorněním, ale normalizace se neprovede.

Normalizace rescales disparate data ranges to a standard scale. Škálování funkcí zajišťuje, že vzdálenosti mezi datovými body jsou proporcionální a umožňují mnohem rychleji konvergovat různé metody optimalizace, jako je gradientní sestup. Při normalizaci MaxMin se použije normalizátor. Normalizuje hodnoty v intervalu [a, b] where -1 <= a <= 0 a 0 <= b <= 1 .b - a = 1 Tento normalizátor zachovává sparsity tím, že namapuje nulu na nulu.

ml_transforms

Určuje seznam transformací MicrosoftML, které se mají provést s daty před trénováním nebo Žádná , pokud se neprovedou žádné transformace. Viz featurize_text, categoricala categorical_hash, pro transformace, které jsou podporovány. Tyto transformace se provádějí po všech zadaných transformacích Pythonu. Výchozí hodnota je None.

ml_transform_vars

Určuje vektor znaku názvů proměnných, který má být použit v ml_transforms nebo None , pokud se žádný použít. Výchozí hodnota je None.

row_selection

NEPODPORUJE SE. Určuje řádky (pozorování) ze sady dat, které má model používat s názvem logické proměnné ze sady dat (v uvozovkách) nebo logickým výrazem pomocí proměnných v sadě dat. Například:

  • row_selection = "old"použije pouze pozorování, ve kterých je Truehodnota proměnné old .

  • row_selection = (age > 20) & (age < 65) & (log(income) > 10) používá pouze pozorování, ve kterých je hodnota age proměnné mezi 20 a 65 a hodnotou logincome proměnné je větší než 10.

Výběr řádku se provede po zpracování všech transformací dat (viz argumenty transforms nebo transform_function). Stejně jako u všech výrazů row_selection je možné definovat mimo volání funkce pomocí expression funkce.

transformuje

NEPODPORUJE SE. Výraz formuláře, který představuje první kolo transformací proměnných. Stejně jako u všech výrazů transforms je možné definovat (nebo row_selection) mimo volání funkce pomocí expression funkce.

transform_objects

NEPODPORUJE SE. Pojmenovaný seznam obsahující objekty, na které lze odkazovat pomocí transforms, transform_functiona row_selection.

transform_function

Proměnná transformační funkce.

transform_variables

Znakový vektor vstupních proměnných množiny dat potřebných pro transformační funkci.

transform_packages

NEPODPORUJE SE. Vektor znaku určující další balíčky Pythonu (mimo balíčky zadané v RxOptions.get_option("transform_packages")) k dispozici a předem načtený pro použití v transformačních funkcích proměnných. Například explicitně definované v funkcích revoscalepy prostřednictvím jejich transforms a transform_function argumentů nebo těch, které jsou definovány implicitně prostřednictvím jejich formula nebo row_selection argumentů. Argumentem transform_packages může být také Žádný, což znamená, že nejsou předem načteny žádné balíčky mimo RxOptions.get_option("transform_packages") .

transform_environment

NEPODPORUJE SE. Uživatelem definované prostředí, které bude sloužit jako nadřazené všem prostředím vyvinutým interně a které se používají k transformaci proměnných dat. Pokud transform_environment = Nonemísto toho použijete nové prostředí hash s nadřazenou hodnotou revoscalepy.baseenvis.

blocks_per_read

Určuje početblokůch

report_progress

Celočíselná hodnota, která určuje úroveň generování sestav o průběhu zpracování řádků:

  • 0: Nebyl hlášen žádný průběh.

  • 1: Počet zpracovaných řádků se vytiskne a aktualizuje.

  • 2: Jsou hlášeny řádky zpracovávané a časování.

  • 3: Jsou hlášeny řádky zpracovávané a všechna časování.

podrobný

Celočíselná hodnota, která určuje požadovanou velikost výstupu. Pokud 0se během výpočtů nevytiskne žádný podrobný výstup. Celočíselné hodnoty, od 1 které se 4 poskytují rostoucí množství informací.

compute_context

Nastaví kontext, ve kterém se výpočty spouštějí, zadané pomocí platné revoscalepy. RxComputeContext. V současné době místní a revoscalepy. Podporují se výpočetní kontexty RxInSqlServer .

Soubor

Kontrolní parametry pro přemíscení

Návraty

Objekt NeuralNetwork s natrénovaným modelem.

Poznámka:

Tento algoritmus je s jedním vláknem a nebude se pokoušet načíst celou datovou sadu do paměti.

Viz také

adadelta_optimizer, sgd_optimizer, avx_math, , clr_mathgpu_math, mkl_math, sse_math, , rx_predict.

Odkazy

Wikipedie: Umělá neurální síť

Příklad binární klasifikace

'''
Binary Classification.
'''
import numpy
import pandas
from microsoftml import rx_neural_network, rx_predict
from revoscalepy.etl.RxDataStep import rx_data_step
from microsoftml.datasets.datasets import get_dataset

infert = get_dataset("infert")

import sklearn
if sklearn.__version__ < "0.18":
    from sklearn.cross_validation import train_test_split
else:
    from sklearn.model_selection import train_test_split

infertdf = infert.as_df()
infertdf["isCase"] = infertdf.case == 1
data_train, data_test, y_train, y_test = train_test_split(infertdf, infertdf.isCase)

forest_model = rx_neural_network(
    formula=" isCase ~ age + parity + education + spontaneous + induced ",
    data=data_train)
    
# RuntimeError: The type (RxTextData) for file is not supported.
score_ds = rx_predict(forest_model, data=data_test,
                     extra_vars_to_write=["isCase", "Score"])
                     
# Print the first five rows
print(rx_data_step(score_ds, number_rows_read=5))

Výstup:

Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.
Beginning processing data.
Rows Read: 186, Read Time: 0, Transform Time: 0
Beginning processing data.
Beginning processing data.
Rows Read: 186, Read Time: 0, Transform Time: 0
Beginning processing data.
Beginning processing data.
Rows Read: 186, Read Time: 0, Transform Time: 0
Beginning processing data.
Using: AVX Math

***** Net definition *****
  input Data [5];
  hidden H [100] sigmoid { // Depth 1
    from Data all;
  }
  output Result [1] sigmoid { // Depth 0
    from H all;
  }
***** End net definition *****
Input count: 5
Output count: 1
Output Function: Sigmoid
Loss Function: LogLoss
PreTrainer: NoPreTrainer
___________________________________________________________________
Starting training...
Learning rate: 0.001000
Momentum: 0.000000
InitWtsDiameter: 0.100000
___________________________________________________________________
Initializing 1 Hidden Layers, 701 Weights...
Estimated Pre-training MeanError = 0.742343
Iter:1/100, MeanErr=0.680245(-8.37%), 119.87M WeightUpdates/sec
Iter:2/100, MeanErr=0.637843(-6.23%), 122.52M WeightUpdates/sec
Iter:3/100, MeanErr=0.635404(-0.38%), 122.24M WeightUpdates/sec
Iter:4/100, MeanErr=0.634980(-0.07%), 73.36M WeightUpdates/sec
Iter:5/100, MeanErr=0.635287(0.05%), 128.26M WeightUpdates/sec
Iter:6/100, MeanErr=0.634572(-0.11%), 131.05M WeightUpdates/sec
Iter:7/100, MeanErr=0.634827(0.04%), 124.27M WeightUpdates/sec
Iter:8/100, MeanErr=0.635359(0.08%), 123.69M WeightUpdates/sec
Iter:9/100, MeanErr=0.635244(-0.02%), 119.35M WeightUpdates/sec
Iter:10/100, MeanErr=0.634712(-0.08%), 127.80M WeightUpdates/sec
Iter:11/100, MeanErr=0.635105(0.06%), 122.69M WeightUpdates/sec
Iter:12/100, MeanErr=0.635226(0.02%), 98.61M WeightUpdates/sec
Iter:13/100, MeanErr=0.634977(-0.04%), 127.88M WeightUpdates/sec
Iter:14/100, MeanErr=0.634347(-0.10%), 123.25M WeightUpdates/sec
Iter:15/100, MeanErr=0.634891(0.09%), 124.27M WeightUpdates/sec
Iter:16/100, MeanErr=0.635116(0.04%), 123.06M WeightUpdates/sec
Iter:17/100, MeanErr=0.633770(-0.21%), 122.05M WeightUpdates/sec
Iter:18/100, MeanErr=0.634992(0.19%), 128.79M WeightUpdates/sec
Iter:19/100, MeanErr=0.634385(-0.10%), 122.95M WeightUpdates/sec
Iter:20/100, MeanErr=0.634752(0.06%), 127.14M WeightUpdates/sec
Iter:21/100, MeanErr=0.635043(0.05%), 123.44M WeightUpdates/sec
Iter:22/100, MeanErr=0.634845(-0.03%), 121.81M WeightUpdates/sec
Iter:23/100, MeanErr=0.634850(0.00%), 125.11M WeightUpdates/sec
Iter:24/100, MeanErr=0.634617(-0.04%), 122.18M WeightUpdates/sec
Iter:25/100, MeanErr=0.634675(0.01%), 125.69M WeightUpdates/sec
Iter:26/100, MeanErr=0.634911(0.04%), 122.44M WeightUpdates/sec
Iter:27/100, MeanErr=0.634311(-0.09%), 121.90M WeightUpdates/sec
Iter:28/100, MeanErr=0.634798(0.08%), 123.54M WeightUpdates/sec
Iter:29/100, MeanErr=0.634674(-0.02%), 127.53M WeightUpdates/sec
Iter:30/100, MeanErr=0.634546(-0.02%), 100.96M WeightUpdates/sec
Iter:31/100, MeanErr=0.634859(0.05%), 124.40M WeightUpdates/sec
Iter:32/100, MeanErr=0.634747(-0.02%), 128.21M WeightUpdates/sec
Iter:33/100, MeanErr=0.634842(0.02%), 125.82M WeightUpdates/sec
Iter:34/100, MeanErr=0.634703(-0.02%), 77.48M WeightUpdates/sec
Iter:35/100, MeanErr=0.634804(0.02%), 122.21M WeightUpdates/sec
Iter:36/100, MeanErr=0.634690(-0.02%), 112.48M WeightUpdates/sec
Iter:37/100, MeanErr=0.634654(-0.01%), 119.18M WeightUpdates/sec
Iter:38/100, MeanErr=0.634885(0.04%), 137.19M WeightUpdates/sec
Iter:39/100, MeanErr=0.634723(-0.03%), 113.80M WeightUpdates/sec
Iter:40/100, MeanErr=0.634714(0.00%), 127.50M WeightUpdates/sec
Iter:41/100, MeanErr=0.634794(0.01%), 129.54M WeightUpdates/sec
Iter:42/100, MeanErr=0.633835(-0.15%), 133.05M WeightUpdates/sec
Iter:43/100, MeanErr=0.634401(0.09%), 128.95M WeightUpdates/sec
Iter:44/100, MeanErr=0.634575(0.03%), 123.42M WeightUpdates/sec
Iter:45/100, MeanErr=0.634673(0.02%), 123.78M WeightUpdates/sec
Iter:46/100, MeanErr=0.634692(0.00%), 119.04M WeightUpdates/sec
Iter:47/100, MeanErr=0.634476(-0.03%), 122.95M WeightUpdates/sec
Iter:48/100, MeanErr=0.634583(0.02%), 97.87M WeightUpdates/sec
Iter:49/100, MeanErr=0.634706(0.02%), 121.41M WeightUpdates/sec
Iter:50/100, MeanErr=0.634564(-0.02%), 120.58M WeightUpdates/sec
Iter:51/100, MeanErr=0.634118(-0.07%), 120.17M WeightUpdates/sec
Iter:52/100, MeanErr=0.634699(0.09%), 127.27M WeightUpdates/sec
Iter:53/100, MeanErr=0.634123(-0.09%), 110.51M WeightUpdates/sec
Iter:54/100, MeanErr=0.634390(0.04%), 123.74M WeightUpdates/sec
Iter:55/100, MeanErr=0.634461(0.01%), 113.66M WeightUpdates/sec
Iter:56/100, MeanErr=0.634415(-0.01%), 118.61M WeightUpdates/sec
Iter:57/100, MeanErr=0.634453(0.01%), 114.99M WeightUpdates/sec
Iter:58/100, MeanErr=0.634478(0.00%), 104.53M WeightUpdates/sec
Iter:59/100, MeanErr=0.634010(-0.07%), 124.62M WeightUpdates/sec
Iter:60/100, MeanErr=0.633901(-0.02%), 118.93M WeightUpdates/sec
Iter:61/100, MeanErr=0.634088(0.03%), 40.46M WeightUpdates/sec
Iter:62/100, MeanErr=0.634046(-0.01%), 94.65M WeightUpdates/sec
Iter:63/100, MeanErr=0.634233(0.03%), 27.18M WeightUpdates/sec
Iter:64/100, MeanErr=0.634596(0.06%), 123.94M WeightUpdates/sec
Iter:65/100, MeanErr=0.634185(-0.06%), 125.01M WeightUpdates/sec
Iter:66/100, MeanErr=0.634469(0.04%), 119.41M WeightUpdates/sec
Iter:67/100, MeanErr=0.634333(-0.02%), 124.11M WeightUpdates/sec
Iter:68/100, MeanErr=0.634203(-0.02%), 112.68M WeightUpdates/sec
Iter:69/100, MeanErr=0.633854(-0.05%), 118.62M WeightUpdates/sec
Iter:70/100, MeanErr=0.634319(0.07%), 123.59M WeightUpdates/sec
Iter:71/100, MeanErr=0.634423(0.02%), 122.51M WeightUpdates/sec
Iter:72/100, MeanErr=0.634388(-0.01%), 126.15M WeightUpdates/sec
Iter:73/100, MeanErr=0.634230(-0.02%), 126.51M WeightUpdates/sec
Iter:74/100, MeanErr=0.634011(-0.03%), 128.32M WeightUpdates/sec
Iter:75/100, MeanErr=0.634294(0.04%), 127.48M WeightUpdates/sec
Iter:76/100, MeanErr=0.634372(0.01%), 123.51M WeightUpdates/sec
Iter:77/100, MeanErr=0.632020(-0.37%), 122.12M WeightUpdates/sec
Iter:78/100, MeanErr=0.633770(0.28%), 119.55M WeightUpdates/sec
Iter:79/100, MeanErr=0.633504(-0.04%), 124.21M WeightUpdates/sec
Iter:80/100, MeanErr=0.634154(0.10%), 125.94M WeightUpdates/sec
Iter:81/100, MeanErr=0.633491(-0.10%), 120.83M WeightUpdates/sec
Iter:82/100, MeanErr=0.634212(0.11%), 128.60M WeightUpdates/sec
Iter:83/100, MeanErr=0.634138(-0.01%), 73.58M WeightUpdates/sec
Iter:84/100, MeanErr=0.634244(0.02%), 124.08M WeightUpdates/sec
Iter:85/100, MeanErr=0.634065(-0.03%), 96.43M WeightUpdates/sec
Iter:86/100, MeanErr=0.634174(0.02%), 124.28M WeightUpdates/sec
Iter:87/100, MeanErr=0.633966(-0.03%), 125.24M WeightUpdates/sec
Iter:88/100, MeanErr=0.633989(0.00%), 130.31M WeightUpdates/sec
Iter:89/100, MeanErr=0.633767(-0.04%), 115.73M WeightUpdates/sec
Iter:90/100, MeanErr=0.633831(0.01%), 122.81M WeightUpdates/sec
Iter:91/100, MeanErr=0.633219(-0.10%), 114.91M WeightUpdates/sec
Iter:92/100, MeanErr=0.633589(0.06%), 93.29M WeightUpdates/sec
Iter:93/100, MeanErr=0.634086(0.08%), 123.31M WeightUpdates/sec
Iter:94/100, MeanErr=0.634075(0.00%), 120.99M WeightUpdates/sec
Iter:95/100, MeanErr=0.634071(0.00%), 122.49M WeightUpdates/sec
Iter:96/100, MeanErr=0.633523(-0.09%), 116.48M WeightUpdates/sec
Iter:97/100, MeanErr=0.634103(0.09%), 128.85M WeightUpdates/sec
Iter:98/100, MeanErr=0.633836(-0.04%), 123.87M WeightUpdates/sec
Iter:99/100, MeanErr=0.633772(-0.01%), 128.17M WeightUpdates/sec
Iter:100/100, MeanErr=0.633684(-0.01%), 123.65M WeightUpdates/sec
Done!
Estimated Post-training MeanError = 0.631268
___________________________________________________________________
Not training a calibrator because it is not needed.
Elapsed time: 00:00:00.2454094
Elapsed time: 00:00:00.0082325
Beginning processing data.
Rows Read: 62, Read Time: 0.001, Transform Time: 0
Beginning processing data.
Elapsed time: 00:00:00.0297006
Finished writing 62 rows.
Writing completed.
Rows Read: 5, Total Rows Processed: 5, Total Chunk Time: 0.001 seconds 
  isCase PredictedLabel     Score  Probability
0   True          False -0.689636     0.334114
1   True          False -0.710219     0.329551
2   True          False -0.712912     0.328956
3  False          False -0.700765     0.331643
4   True          False -0.689783     0.334081

Příklad klasifikace MultiClass

'''
MultiClass Classification.
'''
import numpy
import pandas
from microsoftml import rx_neural_network, rx_predict
from revoscalepy.etl.RxDataStep import rx_data_step
from microsoftml.datasets.datasets import get_dataset

iris = get_dataset("iris")

import sklearn
if sklearn.__version__ < "0.18":
    from sklearn.cross_validation import train_test_split
else:
    from sklearn.model_selection import train_test_split

irisdf = iris.as_df()
irisdf["Species"] = irisdf["Species"].astype("category")
data_train, data_test, y_train, y_test = train_test_split(irisdf, irisdf.Species)

model = rx_neural_network(
    formula="  Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width ",
    method="multiClass",
    data=data_train)
    
# RuntimeError: The type (RxTextData) for file is not supported.
score_ds = rx_predict(model, data=data_test,
                     extra_vars_to_write=["Species", "Score"])
                     
# Print the first five rows
print(rx_data_step(score_ds, number_rows_read=5))

Výstup:

Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.
Beginning processing data.
Rows Read: 112, Read Time: 0.001, Transform Time: 0
Beginning processing data.
Beginning processing data.
Rows Read: 112, Read Time: 0, Transform Time: 0
Beginning processing data.
Beginning processing data.
Rows Read: 112, Read Time: 0, Transform Time: 0
Beginning processing data.
Using: AVX Math

***** Net definition *****
  input Data [4];
  hidden H [100] sigmoid { // Depth 1
    from Data all;
  }
  output Result [3] softmax { // Depth 0
    from H all;
  }
***** End net definition *****
Input count: 4
Output count: 3
Output Function: SoftMax
Loss Function: LogLoss
PreTrainer: NoPreTrainer
___________________________________________________________________
Starting training...
Learning rate: 0.001000
Momentum: 0.000000
InitWtsDiameter: 0.100000
___________________________________________________________________
Initializing 1 Hidden Layers, 803 Weights...
Estimated Pre-training MeanError = 1.949606
Iter:1/100, MeanErr=1.937924(-0.60%), 98.43M WeightUpdates/sec
Iter:2/100, MeanErr=1.921153(-0.87%), 96.21M WeightUpdates/sec
Iter:3/100, MeanErr=1.920000(-0.06%), 95.55M WeightUpdates/sec
Iter:4/100, MeanErr=1.917267(-0.14%), 81.25M WeightUpdates/sec
Iter:5/100, MeanErr=1.917611(0.02%), 102.44M WeightUpdates/sec
Iter:6/100, MeanErr=1.918476(0.05%), 106.16M WeightUpdates/sec
Iter:7/100, MeanErr=1.916096(-0.12%), 97.85M WeightUpdates/sec
Iter:8/100, MeanErr=1.919486(0.18%), 77.99M WeightUpdates/sec
Iter:9/100, MeanErr=1.916452(-0.16%), 95.67M WeightUpdates/sec
Iter:10/100, MeanErr=1.916024(-0.02%), 102.06M WeightUpdates/sec
Iter:11/100, MeanErr=1.917155(0.06%), 99.21M WeightUpdates/sec
Iter:12/100, MeanErr=1.918543(0.07%), 99.25M WeightUpdates/sec
Iter:13/100, MeanErr=1.919120(0.03%), 85.38M WeightUpdates/sec
Iter:14/100, MeanErr=1.917713(-0.07%), 103.00M WeightUpdates/sec
Iter:15/100, MeanErr=1.917675(0.00%), 98.70M WeightUpdates/sec
Iter:16/100, MeanErr=1.917982(0.02%), 99.10M WeightUpdates/sec
Iter:17/100, MeanErr=1.916254(-0.09%), 103.41M WeightUpdates/sec
Iter:18/100, MeanErr=1.915691(-0.03%), 102.00M WeightUpdates/sec
Iter:19/100, MeanErr=1.914844(-0.04%), 86.64M WeightUpdates/sec
Iter:20/100, MeanErr=1.919268(0.23%), 94.68M WeightUpdates/sec
Iter:21/100, MeanErr=1.918748(-0.03%), 108.11M WeightUpdates/sec
Iter:22/100, MeanErr=1.917997(-0.04%), 96.33M WeightUpdates/sec
Iter:23/100, MeanErr=1.914987(-0.16%), 82.84M WeightUpdates/sec
Iter:24/100, MeanErr=1.916550(0.08%), 99.70M WeightUpdates/sec
Iter:25/100, MeanErr=1.915401(-0.06%), 96.69M WeightUpdates/sec
Iter:26/100, MeanErr=1.916092(0.04%), 101.62M WeightUpdates/sec
Iter:27/100, MeanErr=1.916381(0.02%), 98.81M WeightUpdates/sec
Iter:28/100, MeanErr=1.917414(0.05%), 102.29M WeightUpdates/sec
Iter:29/100, MeanErr=1.917316(-0.01%), 100.17M WeightUpdates/sec
Iter:30/100, MeanErr=1.916507(-0.04%), 82.09M WeightUpdates/sec
Iter:31/100, MeanErr=1.915786(-0.04%), 98.33M WeightUpdates/sec
Iter:32/100, MeanErr=1.917581(0.09%), 101.70M WeightUpdates/sec
Iter:33/100, MeanErr=1.913680(-0.20%), 79.94M WeightUpdates/sec
Iter:34/100, MeanErr=1.917264(0.19%), 102.54M WeightUpdates/sec
Iter:35/100, MeanErr=1.917377(0.01%), 100.67M WeightUpdates/sec
Iter:36/100, MeanErr=1.912060(-0.28%), 70.37M WeightUpdates/sec
Iter:37/100, MeanErr=1.917009(0.26%), 80.80M WeightUpdates/sec
Iter:38/100, MeanErr=1.916216(-0.04%), 94.56M WeightUpdates/sec
Iter:39/100, MeanErr=1.916362(0.01%), 28.22M WeightUpdates/sec
Iter:40/100, MeanErr=1.910658(-0.30%), 100.87M WeightUpdates/sec
Iter:41/100, MeanErr=1.916375(0.30%), 85.99M WeightUpdates/sec
Iter:42/100, MeanErr=1.916257(-0.01%), 102.06M WeightUpdates/sec
Iter:43/100, MeanErr=1.914505(-0.09%), 99.86M WeightUpdates/sec
Iter:44/100, MeanErr=1.914638(0.01%), 103.11M WeightUpdates/sec
Iter:45/100, MeanErr=1.915141(0.03%), 107.62M WeightUpdates/sec
Iter:46/100, MeanErr=1.915119(0.00%), 99.65M WeightUpdates/sec
Iter:47/100, MeanErr=1.915379(0.01%), 107.03M WeightUpdates/sec
Iter:48/100, MeanErr=1.912565(-0.15%), 104.78M WeightUpdates/sec
Iter:49/100, MeanErr=1.915466(0.15%), 110.43M WeightUpdates/sec
Iter:50/100, MeanErr=1.914038(-0.07%), 98.44M WeightUpdates/sec
Iter:51/100, MeanErr=1.915015(0.05%), 96.28M WeightUpdates/sec
Iter:52/100, MeanErr=1.913771(-0.06%), 89.27M WeightUpdates/sec
Iter:53/100, MeanErr=1.911621(-0.11%), 72.67M WeightUpdates/sec
Iter:54/100, MeanErr=1.914969(0.18%), 111.17M WeightUpdates/sec
Iter:55/100, MeanErr=1.913894(-0.06%), 98.68M WeightUpdates/sec
Iter:56/100, MeanErr=1.914871(0.05%), 95.41M WeightUpdates/sec
Iter:57/100, MeanErr=1.912898(-0.10%), 80.72M WeightUpdates/sec
Iter:58/100, MeanErr=1.913334(0.02%), 103.71M WeightUpdates/sec
Iter:59/100, MeanErr=1.913362(0.00%), 99.57M WeightUpdates/sec
Iter:60/100, MeanErr=1.913915(0.03%), 106.21M WeightUpdates/sec
Iter:61/100, MeanErr=1.913310(-0.03%), 112.27M WeightUpdates/sec
Iter:62/100, MeanErr=1.913395(0.00%), 50.86M WeightUpdates/sec
Iter:63/100, MeanErr=1.912814(-0.03%), 58.91M WeightUpdates/sec
Iter:64/100, MeanErr=1.911468(-0.07%), 72.06M WeightUpdates/sec
Iter:65/100, MeanErr=1.912313(0.04%), 86.34M WeightUpdates/sec
Iter:66/100, MeanErr=1.913320(0.05%), 114.39M WeightUpdates/sec
Iter:67/100, MeanErr=1.912914(-0.02%), 105.97M WeightUpdates/sec
Iter:68/100, MeanErr=1.909881(-0.16%), 105.73M WeightUpdates/sec
Iter:69/100, MeanErr=1.911649(0.09%), 105.23M WeightUpdates/sec
Iter:70/100, MeanErr=1.911192(-0.02%), 110.24M WeightUpdates/sec
Iter:71/100, MeanErr=1.912480(0.07%), 106.86M WeightUpdates/sec
Iter:72/100, MeanErr=1.909881(-0.14%), 97.28M WeightUpdates/sec
Iter:73/100, MeanErr=1.911678(0.09%), 109.57M WeightUpdates/sec
Iter:74/100, MeanErr=1.911137(-0.03%), 91.01M WeightUpdates/sec
Iter:75/100, MeanErr=1.910706(-0.02%), 99.41M WeightUpdates/sec
Iter:76/100, MeanErr=1.910869(0.01%), 84.18M WeightUpdates/sec
Iter:77/100, MeanErr=1.911643(0.04%), 105.07M WeightUpdates/sec
Iter:78/100, MeanErr=1.911438(-0.01%), 110.12M WeightUpdates/sec
Iter:79/100, MeanErr=1.909590(-0.10%), 84.16M WeightUpdates/sec
Iter:80/100, MeanErr=1.911181(0.08%), 92.30M WeightUpdates/sec
Iter:81/100, MeanErr=1.910534(-0.03%), 110.60M WeightUpdates/sec
Iter:82/100, MeanErr=1.909340(-0.06%), 54.07M WeightUpdates/sec
Iter:83/100, MeanErr=1.908275(-0.06%), 104.08M WeightUpdates/sec
Iter:84/100, MeanErr=1.910364(0.11%), 107.19M WeightUpdates/sec
Iter:85/100, MeanErr=1.910286(0.00%), 102.55M WeightUpdates/sec
Iter:86/100, MeanErr=1.909155(-0.06%), 79.72M WeightUpdates/sec
Iter:87/100, MeanErr=1.909384(0.01%), 102.37M WeightUpdates/sec
Iter:88/100, MeanErr=1.907751(-0.09%), 105.48M WeightUpdates/sec
Iter:89/100, MeanErr=1.910164(0.13%), 102.53M WeightUpdates/sec
Iter:90/100, MeanErr=1.907935(-0.12%), 105.03M WeightUpdates/sec
Iter:91/100, MeanErr=1.909510(0.08%), 99.97M WeightUpdates/sec
Iter:92/100, MeanErr=1.907405(-0.11%), 100.03M WeightUpdates/sec
Iter:93/100, MeanErr=1.905757(-0.09%), 113.21M WeightUpdates/sec
Iter:94/100, MeanErr=1.909167(0.18%), 107.86M WeightUpdates/sec
Iter:95/100, MeanErr=1.907593(-0.08%), 106.09M WeightUpdates/sec
Iter:96/100, MeanErr=1.908358(0.04%), 111.25M WeightUpdates/sec
Iter:97/100, MeanErr=1.906484(-0.10%), 95.81M WeightUpdates/sec
Iter:98/100, MeanErr=1.908239(0.09%), 105.89M WeightUpdates/sec
Iter:99/100, MeanErr=1.908508(0.01%), 103.05M WeightUpdates/sec
Iter:100/100, MeanErr=1.904747(-0.20%), 106.81M WeightUpdates/sec
Done!
Estimated Post-training MeanError = 1.896338
___________________________________________________________________
Not training a calibrator because it is not needed.
Elapsed time: 00:00:00.1620840
Elapsed time: 00:00:00.0096627
Beginning processing data.
Rows Read: 38, Read Time: 0, Transform Time: 0
Beginning processing data.
Elapsed time: 00:00:00.0312987
Finished writing 38 rows.
Writing completed.
Rows Read: 5, Total Rows Processed: 5, Total Chunk Time: Less than .001 seconds 
      Species   Score.0   Score.1   Score.2
0  versicolor  0.350161  0.339557  0.310282
1      setosa  0.358506  0.336593  0.304901
2   virginica  0.346957  0.340573  0.312470
3   virginica  0.346685  0.340748  0.312567
4   virginica  0.348469  0.340113  0.311417

Příklad regrese

'''
Regression.
'''
import numpy
import pandas
from microsoftml import rx_neural_network, rx_predict
from revoscalepy.etl.RxDataStep import rx_data_step
from microsoftml.datasets.datasets import get_dataset

attitude = get_dataset("attitude")

import sklearn
if sklearn.__version__ < "0.18":
    from sklearn.cross_validation import train_test_split
else:
    from sklearn.model_selection import train_test_split

attitudedf = attitude.as_df()
data_train, data_test = train_test_split(attitudedf)

model = rx_neural_network(
    formula="rating ~ complaints + privileges + learning + raises + critical + advance",
    method="regression",
    data=data_train)
    
# RuntimeError: The type (RxTextData) for file is not supported.
score_ds = rx_predict(model, data=data_test,
                     extra_vars_to_write=["rating"])
                     
# Print the first five rows
print(rx_data_step(score_ds, number_rows_read=5))

Výstup:

Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.
Beginning processing data.
Rows Read: 22, Read Time: 0, Transform Time: 0
Beginning processing data.
Beginning processing data.
Rows Read: 22, Read Time: 0.001, Transform Time: 0
Beginning processing data.
Beginning processing data.
Rows Read: 22, Read Time: 0, Transform Time: 0
Beginning processing data.
Using: AVX Math

***** Net definition *****
  input Data [6];
  hidden H [100] sigmoid { // Depth 1
    from Data all;
  }
  output Result [1] linear { // Depth 0
    from H all;
  }
***** End net definition *****
Input count: 6
Output count: 1
Output Function: Linear
Loss Function: SquaredLoss
PreTrainer: NoPreTrainer
___________________________________________________________________
Starting training...
Learning rate: 0.001000
Momentum: 0.000000
InitWtsDiameter: 0.100000
___________________________________________________________________
Initializing 1 Hidden Layers, 801 Weights...
Estimated Pre-training MeanError = 4458.793673
Iter:1/100, MeanErr=1624.747024(-63.56%), 27.30M WeightUpdates/sec
Iter:2/100, MeanErr=139.267390(-91.43%), 30.50M WeightUpdates/sec
Iter:3/100, MeanErr=116.382316(-16.43%), 29.16M WeightUpdates/sec
Iter:4/100, MeanErr=114.947244(-1.23%), 32.06M WeightUpdates/sec
Iter:5/100, MeanErr=112.886818(-1.79%), 32.96M WeightUpdates/sec
Iter:6/100, MeanErr=112.406547(-0.43%), 30.29M WeightUpdates/sec
Iter:7/100, MeanErr=110.502757(-1.69%), 30.92M WeightUpdates/sec
Iter:8/100, MeanErr=111.499645(0.90%), 31.20M WeightUpdates/sec
Iter:9/100, MeanErr=111.895816(0.36%), 32.46M WeightUpdates/sec
Iter:10/100, MeanErr=110.171443(-1.54%), 34.61M WeightUpdates/sec
Iter:11/100, MeanErr=106.975524(-2.90%), 22.14M WeightUpdates/sec
Iter:12/100, MeanErr=107.708220(0.68%), 7.73M WeightUpdates/sec
Iter:13/100, MeanErr=105.345097(-2.19%), 28.99M WeightUpdates/sec
Iter:14/100, MeanErr=109.937833(4.36%), 31.04M WeightUpdates/sec
Iter:15/100, MeanErr=106.672340(-2.97%), 30.04M WeightUpdates/sec
Iter:16/100, MeanErr=108.474555(1.69%), 32.41M WeightUpdates/sec
Iter:17/100, MeanErr=109.449054(0.90%), 31.60M WeightUpdates/sec
Iter:18/100, MeanErr=105.911830(-3.23%), 34.05M WeightUpdates/sec
Iter:19/100, MeanErr=106.045172(0.13%), 33.80M WeightUpdates/sec
Iter:20/100, MeanErr=108.360427(2.18%), 33.60M WeightUpdates/sec
Iter:21/100, MeanErr=106.506436(-1.71%), 33.77M WeightUpdates/sec
Iter:22/100, MeanErr=99.167335(-6.89%), 32.26M WeightUpdates/sec
Iter:23/100, MeanErr=108.115797(9.02%), 25.86M WeightUpdates/sec
Iter:24/100, MeanErr=106.292283(-1.69%), 31.03M WeightUpdates/sec
Iter:25/100, MeanErr=99.397875(-6.49%), 31.33M WeightUpdates/sec
Iter:26/100, MeanErr=104.805299(5.44%), 31.57M WeightUpdates/sec
Iter:27/100, MeanErr=101.385085(-3.26%), 22.92M WeightUpdates/sec
Iter:28/100, MeanErr=100.064656(-1.30%), 35.01M WeightUpdates/sec
Iter:29/100, MeanErr=100.519013(0.45%), 32.74M WeightUpdates/sec
Iter:30/100, MeanErr=99.273143(-1.24%), 35.12M WeightUpdates/sec
Iter:31/100, MeanErr=100.465649(1.20%), 33.68M WeightUpdates/sec
Iter:32/100, MeanErr=102.402320(1.93%), 33.79M WeightUpdates/sec
Iter:33/100, MeanErr=97.517196(-4.77%), 32.32M WeightUpdates/sec
Iter:34/100, MeanErr=102.597511(5.21%), 32.46M WeightUpdates/sec
Iter:35/100, MeanErr=96.187788(-6.25%), 32.32M WeightUpdates/sec
Iter:36/100, MeanErr=101.533507(5.56%), 21.44M WeightUpdates/sec
Iter:37/100, MeanErr=99.339624(-2.16%), 21.53M WeightUpdates/sec
Iter:38/100, MeanErr=98.049306(-1.30%), 15.27M WeightUpdates/sec
Iter:39/100, MeanErr=97.508282(-0.55%), 23.21M WeightUpdates/sec
Iter:40/100, MeanErr=99.894288(2.45%), 27.94M WeightUpdates/sec
Iter:41/100, MeanErr=95.190566(-4.71%), 32.47M WeightUpdates/sec
Iter:42/100, MeanErr=91.234977(-4.16%), 31.29M WeightUpdates/sec
Iter:43/100, MeanErr=98.824414(8.32%), 32.35M WeightUpdates/sec
Iter:44/100, MeanErr=96.759533(-2.09%), 22.37M WeightUpdates/sec
Iter:45/100, MeanErr=95.275106(-1.53%), 32.09M WeightUpdates/sec
Iter:46/100, MeanErr=95.749031(0.50%), 26.49M WeightUpdates/sec
Iter:47/100, MeanErr=96.267879(0.54%), 31.81M WeightUpdates/sec
Iter:48/100, MeanErr=97.383752(1.16%), 31.01M WeightUpdates/sec
Iter:49/100, MeanErr=96.605199(-0.80%), 32.05M WeightUpdates/sec
Iter:50/100, MeanErr=96.927400(0.33%), 32.42M WeightUpdates/sec
Iter:51/100, MeanErr=96.288491(-0.66%), 28.89M WeightUpdates/sec
Iter:52/100, MeanErr=92.751171(-3.67%), 33.68M WeightUpdates/sec
Iter:53/100, MeanErr=88.655001(-4.42%), 34.53M WeightUpdates/sec
Iter:54/100, MeanErr=90.923513(2.56%), 32.00M WeightUpdates/sec
Iter:55/100, MeanErr=91.627261(0.77%), 25.74M WeightUpdates/sec
Iter:56/100, MeanErr=91.132907(-0.54%), 30.00M WeightUpdates/sec
Iter:57/100, MeanErr=95.294092(4.57%), 33.13M WeightUpdates/sec
Iter:58/100, MeanErr=90.219024(-5.33%), 31.70M WeightUpdates/sec
Iter:59/100, MeanErr=92.727605(2.78%), 30.71M WeightUpdates/sec
Iter:60/100, MeanErr=86.910488(-6.27%), 33.07M WeightUpdates/sec
Iter:61/100, MeanErr=92.350984(6.26%), 32.46M WeightUpdates/sec
Iter:62/100, MeanErr=93.208298(0.93%), 31.08M WeightUpdates/sec
Iter:63/100, MeanErr=90.784723(-2.60%), 21.19M WeightUpdates/sec
Iter:64/100, MeanErr=88.685225(-2.31%), 33.17M WeightUpdates/sec
Iter:65/100, MeanErr=91.668555(3.36%), 30.65M WeightUpdates/sec
Iter:66/100, MeanErr=82.607568(-9.88%), 29.72M WeightUpdates/sec
Iter:67/100, MeanErr=88.787842(7.48%), 32.98M WeightUpdates/sec
Iter:68/100, MeanErr=88.793186(0.01%), 34.67M WeightUpdates/sec
Iter:69/100, MeanErr=88.918795(0.14%), 14.09M WeightUpdates/sec
Iter:70/100, MeanErr=87.121434(-2.02%), 33.02M WeightUpdates/sec
Iter:71/100, MeanErr=86.865602(-0.29%), 34.87M WeightUpdates/sec
Iter:72/100, MeanErr=87.261979(0.46%), 32.34M WeightUpdates/sec
Iter:73/100, MeanErr=87.812460(0.63%), 31.35M WeightUpdates/sec
Iter:74/100, MeanErr=87.818462(0.01%), 32.54M WeightUpdates/sec
Iter:75/100, MeanErr=87.085672(-0.83%), 34.80M WeightUpdates/sec
Iter:76/100, MeanErr=85.773668(-1.51%), 35.39M WeightUpdates/sec
Iter:77/100, MeanErr=85.338703(-0.51%), 34.59M WeightUpdates/sec
Iter:78/100, MeanErr=79.370105(-6.99%), 30.14M WeightUpdates/sec
Iter:79/100, MeanErr=83.026209(4.61%), 32.32M WeightUpdates/sec
Iter:80/100, MeanErr=89.776417(8.13%), 33.14M WeightUpdates/sec
Iter:81/100, MeanErr=85.447100(-4.82%), 32.32M WeightUpdates/sec
Iter:82/100, MeanErr=83.991969(-1.70%), 22.12M WeightUpdates/sec
Iter:83/100, MeanErr=85.065064(1.28%), 30.41M WeightUpdates/sec
Iter:84/100, MeanErr=83.762008(-1.53%), 31.29M WeightUpdates/sec
Iter:85/100, MeanErr=84.217726(0.54%), 34.92M WeightUpdates/sec
Iter:86/100, MeanErr=82.395181(-2.16%), 34.26M WeightUpdates/sec
Iter:87/100, MeanErr=82.979145(0.71%), 22.87M WeightUpdates/sec
Iter:88/100, MeanErr=83.656685(0.82%), 28.51M WeightUpdates/sec
Iter:89/100, MeanErr=81.132468(-3.02%), 32.43M WeightUpdates/sec
Iter:90/100, MeanErr=81.311106(0.22%), 30.91M WeightUpdates/sec
Iter:91/100, MeanErr=81.953897(0.79%), 31.98M WeightUpdates/sec
Iter:92/100, MeanErr=79.018074(-3.58%), 33.13M WeightUpdates/sec
Iter:93/100, MeanErr=78.220412(-1.01%), 31.47M WeightUpdates/sec
Iter:94/100, MeanErr=80.833884(3.34%), 25.16M WeightUpdates/sec
Iter:95/100, MeanErr=81.550135(0.89%), 32.64M WeightUpdates/sec
Iter:96/100, MeanErr=77.785628(-4.62%), 32.54M WeightUpdates/sec
Iter:97/100, MeanErr=76.438158(-1.73%), 34.34M WeightUpdates/sec
Iter:98/100, MeanErr=79.471621(3.97%), 33.12M WeightUpdates/sec
Iter:99/100, MeanErr=76.038475(-4.32%), 33.01M WeightUpdates/sec
Iter:100/100, MeanErr=75.349164(-0.91%), 32.68M WeightUpdates/sec
Done!
Estimated Post-training MeanError = 75.768932
___________________________________________________________________
Not training a calibrator because it is not needed.
Elapsed time: 00:00:00.1178557
Elapsed time: 00:00:00.0088299
Beginning processing data.
Rows Read: 8, Read Time: 0, Transform Time: 0
Beginning processing data.
Elapsed time: 00:00:00.0293893
Finished writing 8 rows.
Writing completed.
Rows Read: 5, Total Rows Processed: 5, Total Chunk Time: 0.001 seconds 
   rating      Score
0    82.0  70.120613
1    64.0  66.344688
2    68.0  68.862373
3    58.0  68.241341
4    63.0  67.196869

Optimalizaci

matematika