Skip to content

Commit cbc2e52

Browse files
BenZickelBen Zickel
andauthored
Docker updates (#3435)
* Fix installation of pytorch from a branch in the docker container. * Add docker build trust hosts option and pytorch branch to docker image name. * Install python after SSL config change. * Update Levy-stable parameter fit test accuracy threshold from 0.03 to 0.04 (change is probably due to upgrade to pytorch 2.7.0). --------- Co-authored-by: Ben Zickel <benz@uvisionuav.com>
1 parent 4b80fc7 commit cbc2e52

File tree

4 files changed

+35
-9
lines changed

4 files changed

+35
-9
lines changed

docker/Dockerfile

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ARG uid=1000
1414
ARG gid=1000
1515
ARG ostype=Linux
1616
ARG pyro_git_url=https://github.com/pyro-ppl/pyro.git
17+
ARG trust_hosts=no
1718

1819
# Configurable settings
1920
ENV USER_NAME pyromancer
@@ -37,11 +38,21 @@ RUN bash -c 'if [ ${ostype} == Linux ]; then groupadd -r --gid ${gid} ${USER_NAM
3738
USER ${USER_NAME}
3839

3940
# Install conda
40-
RUN wget -O ~/miniconda.sh \
41+
RUN if [ ${trust_hosts} = yes ] ; then WGET_ARGS="--no-check-certificate" ; fi && \
42+
wget ${WGET_ARGS} -O ~/miniconda.sh \
4143
https://repo.anaconda.com/miniconda/Miniconda${python_version%%.*}-latest-Linux-x86_64.sh && \
4244
bash ~/miniconda.sh -f -b -p ${CONDA_DIR} && \
43-
rm ~/miniconda.sh && \
44-
conda install python=${python_version}
45+
rm ~/miniconda.sh
46+
47+
# Trust conda and pip hosts if needed
48+
RUN if [ ${trust_hosts} = yes ] ; \
49+
then \
50+
pip config set global.trusted-host "pypi.org files.pythonhosted.org download.pytorch.org" && \
51+
conda config --set ssl_verify False ; \
52+
fi
53+
54+
# Update python version
55+
RUN conda install python=${python_version}
4556

4657
# Move to home directory; and copy the install script
4758
WORKDIR ${WORK_DIR}

docker/Makefile

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ python_version?=3.12
2121
pytorch_branch?=release
2222
pyro_branch?=release
2323
cmd?=bash
24+
trust_hosts?="no"
2425

2526
# Determine name of docker image
2627
build run notebook: img_prefix=pyro-cpu
@@ -29,12 +30,11 @@ build run lab: img_prefix=pyro-cpu
2930
build-gpu run-gpu lab-gpu: img_prefix=pyro-gpu
3031

3132
ifeq ($(img), )
32-
IMG_NAME=${img_prefix}-${pyro_branch}-${python_version}
33+
IMG_NAME=${img_prefix}-${pyro_branch}-${pytorch_branch}-${python_version}
3334
else
3435
IMG_NAME=${img}
3536
endif
3637

37-
3838
help:
3939
@fgrep -h "##" ${MAKEFILE_LIST} | fgrep -v fgrep | sed -e 's/##//'
4040

@@ -51,6 +51,9 @@ build: ##
5151
## default - latest pytorch version on the torch python package index
5252
## pyro_branch: whether to use the released Pyro wheel or a git branch.
5353
## default - latest pyro-ppl wheel on pypi
54+
## trust_hosts: If set to yes hosts SSL ceritificates will be trusted
55+
## (might be needed when running begind a firewall)
56+
## default - Verify hosts SSL certificates
5457
##
5558
${DOCKER_CMD} build -t ${IMG_NAME} \
5659
--build-arg base_img=${BASE_IMG} \
@@ -60,7 +63,8 @@ build: ##
6063
--build-arg python_version=${python_version} \
6164
--build-arg pytorch_branch=${pytorch_branch} \
6265
--build-arg pyro_git_url=${pyro_git_url} \
63-
--build-arg pyro_branch=${pyro_branch} -f ${DOCKER_FILE} .
66+
--build-arg pyro_branch=${pyro_branch} \
67+
--build-arg trust_hosts=${trust_hosts} -f ${DOCKER_FILE} .
6468

6569
build-gpu: ##
6670
## Build a docker image for running Pyro on a GPU.
@@ -71,6 +75,9 @@ build-gpu: ##
7175
## default - latest pytorch version on the torch python package index
7276
## pyro_branch: whether to use the released Pyro wheel or a git branch.
7377
## default - latest pyro-ppl wheel on pypi
78+
## trust_hosts: If set to yes hosts SSL ceritificates will be trusted
79+
## (might be needed when running begind a firewall)
80+
## default - Verify hosts SSL certificates
7481
##
7582
${DOCKER_CMD} build -t ${IMG_NAME} \
7683
--build-arg base_img=${BASE_CUDA_IMG} \
@@ -81,7 +88,8 @@ build-gpu: ##
8188
--build-arg python_version=${python_version} \
8289
--build-arg pytorch_branch=${pytorch_branch} \
8390
--build-arg pyro_git_url=${pyro_git_url} \
84-
--build-arg pyro_branch=${pyro_branch} -f ${DOCKER_FILE} .
91+
--build-arg pyro_branch=${pyro_branch} \
92+
--build-arg trust_hosts=${trust_hosts} -f ${DOCKER_FILE} .
8593

8694
create-host-workspace: ##
8795
## Create shared volume on the host for sharing files with the container.

docker/install.sh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@ if [ ${pytorch_branch} != "release" ]
1818
then
1919
git clone --recursive https://github.com/pytorch/pytorch.git
2020
pushd pytorch && git checkout ${pytorch_branch}
21-
pip uninstall torch
21+
pip uninstall -y torch
22+
conda install cmake ninja
23+
pip install -r requirements.txt
24+
pip install mkl-static mkl-include
25+
if [ ${pytorch_whl} != "cpu" ]
26+
then
27+
conda install -c pytorch magma-cuda${pytorch_whl:2}
28+
fi
2229
pip install -e .
2330
popd
2431
fi

tests/distributions/test_stable_log_prob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def log_progress():
129129
train(model, guide)
130130

131131
# Verify fit accuracy
132-
assert_close(alpha, pyro.param("alpha").item(), atol=0.03)
132+
assert_close(alpha, pyro.param("alpha").item(), atol=0.04)
133133
assert_close(beta, pyro.param("beta").item(), atol=0.06)
134134
assert_close(c, pyro.param("c").item(), atol=0.2)
135135
assert_close(mu, pyro.param("mu").item(), atol=0.2)

0 commit comments

Comments
 (0)