public class SGD extends RandomizableClassifier implements UpdateableClassifier, OptionHandler
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-N Don't normalize the data
-M Don't replace missing values
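The six options above map onto a handful of typed settings with the documented defaults. The sketch below is a minimal, hypothetical holder and parser (`SgdSettings` and `parse` are illustrative names, not Weka API); in Weka itself, `setOptions(String[])` does the real parsing. It only shows how the flags, their values and their defaults fit together.

```java
// Hypothetical sketch (not Weka's parser): a plain-Java holder for the six SGD
// options, initialised with the defaults documented above.
final class SgdSettings {
    int lossFunction = 0;               // -F: 0 = hinge, 1 = log, 2 = squared
    double learningRate = 0.01;         // -L
    double lambda = 0.0001;             // -R: regularization constant
    int epochs = 500;                   // -E: batch learning only
    boolean dontNormalize = false;      // -N
    boolean dontReplaceMissing = false; // -M

    /** Returns the value that follows the flag at position i, or fails. */
    private static String valueAfter(String[] args, int i) {
        if (i + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value after " + args[i]);
        }
        return args[i + 1];
    }

    /** Parses the flags; throws on unknown flags, missing values or a bad -F code. */
    static SgdSettings parse(String[] args) {
        SgdSettings s = new SgdSettings();
        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "-F": s.lossFunction = Integer.parseInt(valueAfter(args, i)); i++; break;
                case "-L": s.learningRate = Double.parseDouble(valueAfter(args, i)); i++; break;
                case "-R": s.lambda = Double.parseDouble(valueAfter(args, i)); i++; break;
                case "-E": s.epochs = Integer.parseInt(valueAfter(args, i)); i++; break;
                case "-N": s.dontNormalize = true; break;
                case "-M": s.dontReplaceMissing = true; break;
                default: throw new IllegalArgumentException("Unsupported option: " + args[i]);
            }
        }
        if (s.lossFunction < 0 || s.lossFunction > 2) {
            throw new IllegalArgumentException("-F must be 0, 1 or 2");
        }
        return s;
    }
}
```

For example, `SgdSettings.parse(new String[] {"-F", "1", "-L", "0.0001", "-N"})` selects the log loss, lowers the learning rate and disables normalization, following the advice given for `-L` when normalization is off.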
Modifier and Type | Field and Description |
---|---|
static int | HINGE: the hinge loss function. |
static int | LOGLOSS: the log loss function. |
static int | SQUAREDLOSS: the squared loss function. |
static Tag[] | TAGS_SELECTION: loss functions to choose from. |
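The three loss constants correspond to standard losses. The following sketch writes them out with their derivatives, which is what an SGD step needs. Assumptions: the codes 0, 1 and 2 follow the `-F` option (the actual values of Weka's `HINGE`, `LOGLOSS` and `SQUAREDLOSS` fields are not shown on this page), the classification losses are written in terms of the margin z = y·f(x) with y in {-1, +1}, and the squared loss in terms of the residual. `SgdLosses` is a hypothetical name.

```java
// Sketch of the three losses, as functions of z. For HINGE and LOGLOSS, z is the
// margin y * f(x) with y in {-1, +1}; for SQUAREDLOSS, z is the residual y - f(x).
// The integer codes mirror the -F option; Weka's own constant values are assumed.
final class SgdLosses {
    static final int HINGE = 0, LOGLOSS = 1, SQUAREDLOSS = 2;

    /** Value of the selected loss at z. */
    static double loss(int type, double z) {
        switch (type) {
            case HINGE:
                return Math.max(0.0, 1.0 - z);
            case LOGLOSS:
                // log(1 + e^-z), rearranged so that large negative z does not overflow
                return z >= 0 ? Math.log1p(Math.exp(-z)) : -z + Math.log1p(Math.exp(z));
            case SQUAREDLOSS:
                return 0.5 * z * z;
            default:
                throw new IllegalArgumentException("Unknown loss: " + type);
        }
    }

    /** Derivative (a subgradient for HINGE) of the selected loss with respect to z. */
    static double dLoss(int type, double z) {
        switch (type) {
            case HINGE:       return z < 1.0 ? -1.0 : 0.0;
            case LOGLOSS:     return -1.0 / (1.0 + Math.exp(z));
            case SQUAREDLOSS: return z;
            default: throw new IllegalArgumentException("Unknown loss: " + type);
        }
    }
}
```

The hinge loss is exactly zero once the margin reaches 1, so correctly classified examples with a comfortable margin stop contributing updates; the log loss never reaches zero, which is what allows it to yield probability estimates.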
Constructor and Description |
---|
SGD() |
Modifier and Type | Method and Description |
---|---|
void | buildClassifier(Instances data): Method for building the classifier. |
double[] | distributionForInstance(Instance inst): Computes the distribution for a given instance. |
String | dontNormalizeTipText(): Returns the tip text for this property. |
String | dontReplaceMissingTipText(): Returns the tip text for this property. |
String | epochsTipText(): Returns the tip text for this property. |
Capabilities | getCapabilities(): Returns default capabilities of the classifier. |
boolean | getDontNormalize(): Get whether normalization has been turned off. |
boolean | getDontReplaceMissing(): Get whether global replacement of missing values has been disabled. |
int | getEpochs(): Get the current number of epochs. |
double | getLambda(): Get the current value of lambda. |
double | getLearningRate(): Get the learning rate. |
SelectedTag | getLossFunction(): Get the current loss function. |
String[] | getOptions(): Gets the current settings of the classifier. |
String | getRevision(): Returns the revision string. |
String | globalInfo(): Returns a string describing this classifier. |
String | lambdaTipText(): Returns the tip text for this property. |
String | learningRateTipText(): Returns the tip text for this property. |
Enumeration<Option> | listOptions(): Returns an enumeration describing the available options. |
String | lossFunctionTipText(): Returns the tip text for this property. |
static void | main(String[] args): Main method for testing this class. |
void | reset(): Reset the classifier. |
void | setDontNormalize(boolean m): Turn normalization off/on. |
void | setDontReplaceMissing(boolean m): Turn global replacement of missing values off/on. |
void | setEpochs(int e): Set the number of epochs to use. |
void | setLambda(double lambda): Set the value of lambda to use. |
void | setLearningRate(double lr): Set the learning rate. |
void | setLossFunction(SelectedTag function): Set the loss function to use. |
void | setOptions(String[] options): Parses a given list of options. |
String | toString(): Prints out the classifier. |
void | updateClassifier(Instance instance): Updates the classifier with the given instance. |
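Methods inherited from class RandomizableClassifier: getSeed, seedTipText, setSeed
Methods inherited from class AbstractClassifier: classifyInstance, debugTipText, forName, getDebug, makeCopies, makeCopy, runClassifier, setDebug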
public static final int HINGE
public static final int LOGLOSS
public static final int SQUAREDLOSS
public static final Tag[] TAGS_SELECTION
public Capabilities getCapabilities()
Returns default capabilities of the classifier.
Specified by: getCapabilities in interface Classifier
Specified by: getCapabilities in interface CapabilitiesHandler
Overrides: getCapabilities in class AbstractClassifier
See Also: Capabilities
public String lambdaTipText()
public void setLambda(double lambda)
Set the value of lambda to use.
Parameters: lambda - the value of lambda to use
public double getLambda()
Get the current value of lambda.
public void setLearningRate(double lr)
Set the learning rate.
Parameters: lr - the learning rate to use.
public double getLearningRate()
Get the learning rate.
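The learning rate (`-L`) and lambda (`-R`) meet in every update. The sketch below shows one L2-regularized SGD step for a linear model under the hinge loss: the weights first shrink by a factor (1 - eta * lambda), then, if the margin is below 1, move by eta * y * x. It is a textbook form of the update, assumed for illustration rather than taken from Weka's source; `SgdStep`, `hingeStep` and the {-1, +1} label encoding are hypothetical.

```java
// Sketch of one L2-regularized SGD step with the hinge loss (not Weka's code).
// eta plays the role of the learning rate (-L), lambda of the regularization
// constant (-R). Labels are assumed to be -1 or +1.
final class SgdStep {
    /** Updates w in place; w has one slot per attribute plus a final, unregularized bias. */
    static void hingeStep(double[] w, double[] x, double y, double eta, double lambda) {
        int d = x.length;
        double f = w[d];                          // start from the bias
        for (int j = 0; j < d; j++) f += w[j] * x[j];
        double z = y * f;                         // margin of this example
        double shrink = 1.0 - eta * lambda;       // weight decay from the L2 penalty
        for (int j = 0; j < d; j++) w[j] *= shrink;
        if (z < 1.0) {                            // hinge subgradient is non-zero
            for (int j = 0; j < d; j++) w[j] += eta * y * x[j];
            w[d] += eta * y;
        }
    }
}
```

Because the step is proportional to x, unscaled attributes with large magnitudes produce proportionally larger jumps; this is consistent with the `-L` note that the default learning rate should be reduced when normalization is turned off.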
public String learningRateTipText()
public String epochsTipText()
public void setEpochs(int e)
Set the number of epochs to use.
Parameters: e - the number of epochs to use
public int getEpochs()
Get the current number of epochs.
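In batch mode, the epoch count (`-E`) is the number of full passes over the training data. The minimal sketch below makes each epoch visit every example once in a freshly shuffled order, driven by a seed in the spirit of the `setSeed` option inherited from `RandomizableClassifier`, and applies the regularized hinge step to each example. Whether Weka reshuffles on every epoch is an assumption here, as are the hypothetical names `SgdEpochs`, `train` and `predict` and the {-1, +1} labels.

```java
// Sketch of epoch-based batch SGD training with the hinge loss (not Weka's code).
// Assumes at least one example, all of the same length, and labels in {-1, +1}.
final class SgdEpochs {
    static double[] train(double[][] xs, double[] ys, int epochs,
                          double eta, double lambda, long seed) {
        int d = xs[0].length;
        double[] w = new double[d + 1];              // last slot is the bias
        java.util.Random rnd = new java.util.Random(seed);
        int[] order = new int[xs.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        for (int e = 0; e < epochs; e++) {
            // Fisher-Yates shuffle of the order in which examples are visited
            for (int i = order.length - 1; i > 0; i--) {
                int k = rnd.nextInt(i + 1);
                int t = order[i]; order[i] = order[k]; order[k] = t;
            }
            for (int idx : order) {
                double[] x = xs[idx];
                double y = ys[idx];
                double f = w[d];
                for (int j = 0; j < d; j++) f += w[j] * x[j];
                for (int j = 0; j < d; j++) w[j] *= 1.0 - eta * lambda; // L2 decay
                if (y * f < 1.0) {                                      // margin violated
                    for (int j = 0; j < d; j++) w[j] += eta * y * x[j];
                    w[d] += eta * y;
                }
            }
        }
        return w;
    }

    /** Predicted label (-1 or +1) from the sign of the linear score. */
    static double predict(double[] w, double[] x) {
        double f = w[x.length];
        for (int j = 0; j < x.length; j++) f += w[j] * x[j];
        return f >= 0 ? 1.0 : -1.0;
    }
}
```

Fixing the seed makes training repeatable: two calls with the same data, settings and seed return identical weights.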
public void setDontNormalize(boolean m)
Turn normalization off/on.
Parameters: m - true if normalization is to be disabled.
public boolean getDontNormalize()
Get whether normalization has been turned off.
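Normalization rescales each attribute before training. The sketch below shows min-max scaling to [0, 1], one common form of this preprocessing. That `SGD` applies exactly this scaling, rather than another scheme such as standardization, is an assumption here, and `MinMaxScaler` is a hypothetical name.

```java
// Sketch of min-max normalization of numeric attributes to [0, 1], the kind of
// rescaling that setDontNormalize(true) / -N is described as turning off.
final class MinMaxScaler {
    final double[] min, max;

    /** Learns per-attribute minima and maxima from the full training data. */
    MinMaxScaler(double[][] data) {
        int d = data[0].length;
        min = new double[d];
        max = new double[d];
        java.util.Arrays.fill(min, Double.POSITIVE_INFINITY);
        java.util.Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : data) {
            for (int j = 0; j < d; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    /** Returns a rescaled copy of row; constant attributes are mapped to 0. */
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            double range = max[j] - min[j];
            out[j] = range == 0.0 ? 0.0 : (row[j] - min[j]) / range;
        }
        return out;
    }
}
```

The scaler has to see all of the data before it can transform anything. That is one plausible reason why normalization is switched off automatically for streaming data, where instances arrive one at a time.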
public String dontNormalizeTipText()
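public void setDontReplaceMissing(boolean m)
Turn global replacement of missing values off/on.
Parameters: m - true if global replacement of missing values is to be turned off.
public boolean getDontReplaceMissing()
Get whether global replacement of missing values has been disabled.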
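"Global replacement of missing values" means filling each gap with a statistic computed over the whole training set. The sketch below covers numeric attributes only. It encodes missing values as `NaN` and replaces each one with the mean of the observed values of its attribute. Nominal attributes, which are usually filled with the most frequent value, are left out. The class name `MeanImputer` is hypothetical, and the exact statistic `SGD` uses is an assumption.

```java
// Sketch of global mean imputation for numeric attributes (missing = NaN here).
// The input array is not modified; a cleaned copy is returned.
final class MeanImputer {
    static double[][] impute(double[][] data) {
        int d = data[0].length;
        double[] sum = new double[d];
        int[] count = new int[d];
        // First pass: per-attribute sums and counts of observed values
        for (double[] row : data)
            for (int j = 0; j < d; j++)
                if (!Double.isNaN(row[j])) { sum[j] += row[j]; count[j]++; }
        // Second pass: copy rows, filling gaps with the attribute mean
        double[][] out = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            out[i] = data[i].clone();
            for (int j = 0; j < d; j++)
                if (Double.isNaN(out[i][j]))
                    out[i][j] = count[j] > 0 ? sum[j] / count[j] : 0.0; // all-missing column -> 0
        }
        return out;
    }
}
```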
public String dontReplaceMissingTipText()
public void setLossFunction(SelectedTag function)
Set the loss function to use.
Parameters: function - the loss function to use.
public SelectedTag getLossFunction()
Get the current loss function.
public String lossFunctionTipText()
public Enumeration<Option> listOptions()
Returns an enumeration describing the available options.
Specified by: listOptions in interface OptionHandler
Overrides: listOptions in class RandomizableClassifier
public void setOptions(String[] options) throws Exception
Parses a given list of options. Valid options are -F, -L, -R, -E, -N and -M, as described at the top of this page.
Specified by: setOptions in interface OptionHandler
Overrides: setOptions in class RandomizableClassifier
Parameters: options - the list of options as an array of strings
Throws: Exception - if an option is not supported
public String[] getOptions()
Gets the current settings of the classifier.
Specified by: getOptions in interface OptionHandler
Overrides: getOptions in class RandomizableClassifier
public String globalInfo()
public void reset()
public void buildClassifier(Instances data) throws Exception
Method for building the classifier.
Specified by: buildClassifier in interface Classifier
Parameters: data - the set of training instances.
Throws: Exception - if the classifier can't be built successfully.
public void updateClassifier(Instance instance) throws Exception
Updates the classifier with the given instance.
Specified by: updateClassifier in interface UpdateableClassifier
Parameters: instance - the new training instance to include in the model
Throws: Exception - if the instance could not be incorporated in the model.
public double[] distributionForInstance(Instance inst) throws Exception
Computes the distribution for a given instance.
Specified by: distributionForInstance in interface Classifier
Overrides: distributionForInstance in class AbstractClassifier
Parameters: inst - the instance for which the distribution is computed
Throws: Exception - if the distribution can't be computed successfully
public String toString()
Prints out the classifier.
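`updateClassifier` and `distributionForInstance` together support incremental learning: the model absorbs one instance at a time and can be queried at any point. The sketch below mirrors that pattern for the log loss, outside Weka. Assumptions: two classes with indices 0 and 1 (mapped to y = -1 and +1), no normalization (as for streaming data), and class probabilities from the logistic function of the linear score. `OnlineLogistic`, `update` and `distribution` are hypothetical names. With Weka itself, the usual `UpdateableClassifier` pattern is to call `buildClassifier` once and then `updateClassifier` for each new `Instance`.

```java
// Sketch of online logistic regression trained by SGD, one instance per update
// (not Weka's code). Class index 0 maps to y = -1 and class index 1 to y = +1.
final class OnlineLogistic {
    private final double[] w;          // one weight per attribute, last slot is the bias
    private final double eta, lambda;  // learning rate and L2 regularization constant

    OnlineLogistic(int numAttributes, double eta, double lambda) {
        this.w = new double[numAttributes + 1];
        this.eta = eta;
        this.lambda = lambda;
    }

    private double score(double[] x) {
        double f = w[x.length];
        for (int j = 0; j < x.length; j++) f += w[j] * x[j];
        return f;
    }

    /** One SGD step on the log loss log(1 + e^-z) for a single labelled instance. */
    void update(double[] x, int classIndex) {
        double y = classIndex == 0 ? -1.0 : 1.0;
        double z = y * score(x);
        double g = 1.0 / (1.0 + Math.exp(z));   // equals -dLoss/dz
        for (int j = 0; j < x.length; j++) {
            w[j] = w[j] * (1.0 - eta * lambda) + eta * g * y * x[j];
        }
        w[x.length] += eta * g * y;             // bias is not regularized
    }

    /** Returns {P(class 0), P(class 1)} for the instance x. */
    double[] distribution(double[] x) {
        double p1 = 1.0 / (1.0 + Math.exp(-score(x)));
        return new double[] {1.0 - p1, p1};
    }
}
```

Before any update the model is maximally uncertain and returns {0.5, 0.5}. Each call to `update` nudges the weights toward the observed label, so probabilities sharpen as the stream goes on.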
public String getRevision()
Returns the revision string.
Specified by: getRevision in interface RevisionHandler
Overrides: getRevision in class AbstractClassifier
public static void main(String[] args)
Copyright © 2012 University of Waikato, Hamilton, NZ. All Rights Reserved.