Performing analyses and visualizing results

# Silence library warnings to keep the notebook output readable
import warnings
warnings.filterwarnings("ignore")

Now that we’ve seen how to create a connectome for an individual subject, we’re ready to think about how we can use this connectome in a machine learning analysis. We’ll keep working with the same development_dataset, but now we’d like to see if we can predict age group (i.e. whether a participant is a child or adult) based on their connectome, as defined by the functional connectivity matrix.

We’ll also explore whether we’re more or less accurate in our predictions based on how we define functional connectivity. In this example, we’ll consider three different ways to define functional connectivity between our Multi-Subject Dictionary Learning (MSDL) regions of interest (ROIs): correlation, partial correlation, and tangent space embedding.

To learn more about tangent space embedding and how it compares to standard correlations, we recommend [Dadi et al., 2019].

Load brain development fMRI dataset and MSDL atlas

First, we need to set up our minimal environment. This includes importing all the dependencies from the last notebook, loading the relevant data using the Nilearn dataset fetchers, and instantiating our NiftiMapsMasker and ConnectivityMeasure objects.

import numpy as np
import matplotlib.pyplot as plt
from nilearn import (datasets, input_data, plotting)
from nilearn.connectome import ConnectivityMeasure

development_dataset = datasets.fetch_development_fmri(n_subjects=30)
msdl_atlas = datasets.fetch_atlas_msdl()

masker = input_data.NiftiMapsMasker(
    msdl_atlas.maps, resampling_target="data",
    t_r=2, detrend=True,
    low_pass=0.1, high_pass=0.01).fit()
correlation_measure = ConnectivityMeasure(kind='correlation')

Now we should have a much better idea of what each line above is doing! Let’s see how we can use these objects across many subjects, not just the first one.

Region signals extraction

First, we can loop through the 30 participants and extract a few relevant pieces of information, including their functional scan, their confounds file, and whether they were a child or adult at the time of their scan.

Using this information, we can then transform their data using the NiftiMapsMasker we created above. As we learned last time, it’s really important to correct for known sources of noise! So we’ll also pass the relevant confounds file directly to the masker object to clean up each subject’s data.

children = []
pooled_subjects = []
groups = []  # child or adult

for func_file, confound_file, phenotypic in zip(
        development_dataset.func,
        development_dataset.confounds,
        development_dataset.phenotypic):

    time_series = masker.transform(func_file, confounds=confound_file)
    pooled_subjects.append(time_series)

    if phenotypic['Child_Adult'] == 'child':
        children.append(time_series)

    groups.append(phenotypic['Child_Adult'])

print('Data has {0} children.'.format(len(children)))

We can see that this data set has 24 children. This is roughly proportional to the original participant pool, which had 122 children and 33 adults.
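
If we want the full breakdown, we can tally the groups list we just built:

# Count children and adults in our 30-subject sample
n_children = groups.count('child')
n_adults = groups.count('adult')
print('Data has {0} children and {1} adults.'.format(n_children, n_adults))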

We’ve also created a list in pooled_subjects containing all of the cleaned data. Remember that each entry of that list should have a shape of (168, 39). We can quickly confirm that this is true:

print(pooled_subjects[0].shape)
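
That only checks the first subject; a one-line assertion confirms that every subject has the same shape:

# Confirm that every subject's time series has the expected shape
assert all(ts.shape == (168, 39) for ts in pooled_subjects)
print('All {0} subjects have time series of shape (168, 39).'.format(
    len(pooled_subjects)))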

ROI-to-ROI correlations of children

First, we’ll use the most common kind of connectivity, and the one we used in the last section: correlation. It models the full (marginal) connectivity between pairs of ROIs.

correlation_measure expects a list of time series, so we can directly supply the list of ROI time series we just created. It will then compute individual correlation matrices for each subject. First, let’s just look at the correlation matrices for our 24 children, since we expect these matrices to be similar:

correlation_matrices = correlation_measure.fit_transform(children)

Now, all individual correlation matrices are stacked into a single 3D array, with one matrix per subject.

print('Correlations of children are stacked in an array of shape {0}'
      .format(correlation_matrices.shape))

We can also directly access the average correlation across all fitted subjects using the mean_ attribute.

mean_correlation_matrix = correlation_measure.mean_
print('Mean correlation has shape {0}.'.format(mean_correlation_matrix.shape))
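
As a quick sanity check, for the correlation kind this group mean should agree with the arithmetic average of the stacked matrices. The small sketch below compares them with np.allclose rather than exact equality, since nilearn may symmetrize the result numerically:

# mean_ should match the element-wise average of the individual matrices
print(np.allclose(mean_correlation_matrix,
                  correlation_matrices.mean(axis=0)))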

Let’s display the functional connectivity matrices of the first 3 children:

_, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (matrix, ax) in enumerate(zip(correlation_matrices, axes)):
    plotting.plot_matrix(matrix, colorbar=False, axes=ax,
                         vmin=-0.8, vmax=0.8,
                         title='correlation, child {}'.format(i))

Just as before, we can also display the connectome on the brain. Here, let’s show the mean connectome over all 24 children.

plotting.view_connectome(mean_correlation_matrix, msdl_atlas.region_coords,
                         edge_threshold=0.2,
                         title='mean connectome over all children')

Studying partial correlations

Rather than looking at the correlation-defined functional connectivity matrix, we can also study direct connections as revealed by partial correlation coefficients.

To do this, we can use exactly the same procedure as above, just changing the ConnectivityMeasure kind:

partial_correlation_measure = ConnectivityMeasure(kind='partial correlation')
partial_correlation_matrices = partial_correlation_measure.fit_transform(
    children)

Right away, we can see that most of the direct connections are weaker than the full connections for the first three children:

_, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (matrix, ax) in enumerate(zip(partial_correlation_matrices, axes)):
    plotting.plot_matrix(matrix, colorbar=False, axes=ax,
                         vmin=-0.8, vmax=0.8,
                         title='partial correlation, child {}'.format(i))
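
We can put a rough number on this impression by comparing the average absolute off-diagonal strength of the two kinds of matrices (a small check reusing the arrays computed above):

# Compare average absolute off-diagonal strength: full vs. partial correlations
off_diag = ~np.eye(correlation_matrices.shape[-1], dtype=bool)
print('Mean |full correlation|:    {:.3f}'.format(
    np.abs(correlation_matrices[:, off_diag]).mean()))
print('Mean |partial correlation|: {:.3f}'.format(
    np.abs(partial_correlation_matrices[:, off_diag]).mean()))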

This is also visible when we display the mean partial correlation connectome:

plotting.view_connectome(
    partial_correlation_measure.mean_, msdl_atlas.region_coords,
    edge_threshold=0.2,
    title='mean partial correlation over all children')

Using tangent space embedding

An alternative to both correlations and partial correlations is tangent space embedding. Tangent space embedding uses both correlations and partial correlations to capture reproducible connectivity patterns at the group level.
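
Intuitively, each subject’s covariance matrix is whitened by the group mean and mapped through a matrix logarithm, so that each subject is represented by its deviation from the group in a common tangent space. Here is a rough conceptual sketch of that projection. It is not nilearn’s exact implementation, which estimates a geometric group mean iteratively and uses regularized covariance estimates:

def tangent_embedding_sketch(subject_cov, group_mean_cov):
    """Project one subject's covariance into the tangent space at the group mean.

    Conceptual illustration only: assumes both inputs are symmetric positive
    definite; nilearn's 'tangent' kind uses a geometric (not arithmetic) group
    mean and shrunk covariance estimates.
    """
    # Inverse square root of the group mean, via eigendecomposition
    eigvals, eigvecs = np.linalg.eigh(group_mean_cov)
    whitening = eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T
    # Whiten the subject covariance by the group mean...
    whitened = whitening @ subject_cov @ whitening
    # ...then take the matrix logarithm of the (symmetric) result
    w_vals, w_vecs = np.linalg.eigh(whitened)
    return w_vecs @ np.diag(np.log(w_vals)) @ w_vecs.T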

Using this method is as easy as changing the kind of ConnectivityMeasure:

tangent_measure = ConnectivityMeasure(kind='tangent')

We fit our children group and get the group connectivity matrix stored in tangent_measure.mean_, along with individual deviation matrices for each subject.

tangent_matrices = tangent_measure.fit_transform(children)

tangent_matrices model individual connectivities as perturbations of the group connectivity matrix tangent_measure.mean_. Keep in mind that these subject-to-group variability matrices do not directly reflect individual brain connections. For instance, negative coefficients cannot be interpreted as anticorrelated regions.

_, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (matrix, ax) in enumerate(zip(tangent_matrices, axes)):
    plotting.plot_matrix(matrix, colorbar=False, axes=ax,
                         vmin=-0.8, vmax=0.8,
                         title='tangent offset, child {}'.format(i))

We don’t show the mean connectome here, since the average tangent matrix cannot be meaningfully interpreted: the individual matrices represent deviations from the group mean, so their average is (approximately) zero.
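
We can check this numerically. Because the geometric mean is estimated iteratively, the average deviation is not exactly zero, but it should be tiny compared to the individual entries:

# Average deviation across children should be close to zero
print('Max |mean tangent offset|:       {:.4f}'.format(
    np.abs(tangent_matrices.mean(axis=0)).max()))
print('Max |individual tangent offset|: {:.4f}'.format(
    np.abs(tangent_matrices).max()))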

Using connectivity in a classification analysis

We can use these connectivity matrices as features in a classification analysis to distinguish children from adults. This classification analysis can be implemented directly in scikit-learn, including all of the important considerations like cross-validation and measuring classification accuracy.

First, we’ll randomly split participants into training and testing sets 15 times. StratifiedShuffleSplit allows us to preserve the proportion of children to adults in the test set. We’ll also compute classification accuracies for each of the kinds of functional connectivity we’ve identified: correlation, partial correlation, and tangent space embedding.

from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.svm import LinearSVC

kinds = ['correlation', 'partial correlation', 'tangent']
_, classes = np.unique(groups, return_inverse=True)
cv = StratifiedShuffleSplit(n_splits=15, random_state=0, test_size=5)
pooled_subjects = np.asarray(pooled_subjects)

Now, we can train the scikit-learn LinearSVC estimator on our training set of participants and apply the trained classifier to our testing set, storing accuracy scores after each cross-validation fold:

scores = {}
for kind in kinds:
    scores[kind] = []
    for train, test in cv.split(pooled_subjects, classes):
        # ConnectivityMeasure can output each subject's estimated coefficients
        # as a 1D array through the *vectorize* parameter.
        connectivity = ConnectivityMeasure(kind=kind, vectorize=True)
        # build vectorized connectomes for subjects in the train set
        connectomes = connectivity.fit_transform(pooled_subjects[train])
        # fit the classifier
        classifier = LinearSVC().fit(connectomes, classes[train])
        # make predictions for the left-out test subjects
        predictions = classifier.predict(
            connectivity.transform(pooled_subjects[test]))
        # store the accuracy for this cross-validation fold
        scores[kind].append(accuracy_score(classes[test], predictions))

After we’ve done this for all of the folds, we can display the results!

mean_scores = [np.mean(scores[kind]) for kind in kinds]
scores_std = [np.std(scores[kind]) for kind in kinds]

plt.figure(figsize=(6, 4))
positions = np.arange(len(kinds)) * .1 + .1
plt.barh(positions, mean_scores, align='center', height=.05, xerr=scores_std)
yticks = [k.replace(' ', '\n') for k in kinds]
plt.yticks(positions, yticks)
plt.gca().grid(True)
plt.gca().set_axisbelow(True)
plt.gca().axvline(.8, color='red', linestyle='--')
plt.xlabel('Classification accuracy\n(red line = chance level)')
plt.tight_layout()
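
If you’d rather see the raw numbers alongside the plot, the same scores can be printed directly:

# Print mean accuracy (with standard deviation across folds) for each kind
for kind, mean_score, std_score in zip(kinds, mean_scores, scores_std):
    print('{:<20} {:.2f} +/- {:.2f}'.format(kind, mean_score, std_score))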

This is a small example to showcase nilearn features. In practice such comparisons need to be performed on much larger cohorts and several datasets. [Dadi et al., 2019] showed that across many cohorts and clinical questions, the tangent kind should be preferred.

Combining nilearn and scikit-learn can allow us to perform many (many) kinds of machine learning analyses, not just classification! We encourage you to explore the Examples and User Guides on the Nilearn website to learn more!

References

Kamalaker Dadi, Mehdi Rahim, Alexandre Abraham, Darya Chyzhyk, Michael Milham, Bertrand Thirion, Gaël Varoquaux, and Alzheimer's Disease Neuroimaging Initiative. Benchmarking functional connectome-based predictive models for resting-state fMRI. NeuroImage, 192:115–134, May 2019.